xref: /llvm-project/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp (revision 4f4e2abb1a5ff1225d32410fd02b732d077aa056)
1 //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements mlir::applyPatternsGreedily.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
14 
15 #include "mlir/Config/mlir-config.h"
16 #include "mlir/IR/Action.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/Verifier.h"
19 #include "mlir/Interfaces/SideEffectInterfaces.h"
20 #include "mlir/Rewrite/PatternApplicator.h"
21 #include "mlir/Transforms/FoldUtils.h"
22 #include "mlir/Transforms/RegionUtils.h"
23 #include "llvm/ADT/BitVector.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/ScopeExit.h"
26 #include "llvm/Support/CommandLine.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/ScopedPrinter.h"
29 #include "llvm/Support/raw_ostream.h"
30 
31 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
32 #include <random>
33 #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
34 
35 using namespace mlir;
36 
37 #define DEBUG_TYPE "greedy-rewriter"
38 
39 namespace {
40 
41 //===----------------------------------------------------------------------===//
42 // Debugging Infrastructure
43 //===----------------------------------------------------------------------===//
44 
45 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
46 /// A helper struct that performs various "expensive checks" to detect broken
47 /// rewrite patterns use the rewriter API incorrectly. A rewrite pattern is
48 /// broken if:
49 /// * IR does not verify after pattern application / folding.
50 /// * Pattern returns "failure" but the IR has changed.
51 /// * Pattern returns "success" but the IR has not changed.
52 ///
53 /// This struct stores finger prints of ops to determine whether the IR has
54 /// changed or not.
55 struct ExpensiveChecks : public RewriterBase::ForwardingListener {
56   ExpensiveChecks(RewriterBase::Listener *driver, Operation *topLevel)
57       : RewriterBase::ForwardingListener(driver), topLevel(topLevel) {}
58 
59   /// Compute finger prints of the given op and its nested ops.
60   void computeFingerPrints(Operation *topLevel) {
61     this->topLevel = topLevel;
62     this->topLevelFingerPrint.emplace(topLevel);
63     topLevel->walk([&](Operation *op) {
64       fingerprints.try_emplace(op, op, /*includeNested=*/false);
65     });
66   }
67 
68   /// Clear all finger prints.
69   void clear() {
70     topLevel = nullptr;
71     topLevelFingerPrint.reset();
72     fingerprints.clear();
73   }
74 
75   void notifyRewriteSuccess() {
76     if (!topLevel)
77       return;
78 
79     // Make sure that the IR still verifies.
80     if (failed(verify(topLevel)))
81       llvm::report_fatal_error("IR failed to verify after pattern application");
82 
83     // Pattern application success => IR must have changed.
84     OperationFingerPrint afterFingerPrint(topLevel);
85     if (*topLevelFingerPrint == afterFingerPrint) {
86       // Note: Run "mlir-opt -debug" to see which pattern is broken.
87       llvm::report_fatal_error(
88           "pattern returned success but IR did not change");
89     }
90     for (const auto &it : fingerprints) {
91       // Skip top-level op, its finger print is never invalidated.
92       if (it.first == topLevel)
93         continue;
94       // Note: Finger print computation may crash when an op was erased
95       // without notifying the rewriter. (Run with ASAN to see where the op was
96       // erased; the op was probably erased directly, bypassing the rewriter
97       // API.) Finger print computation does may not crash if a new op was
98       // created at the same memory location. (But then the finger print should
99       // have changed.)
100       if (it.second !=
101           OperationFingerPrint(it.first, /*includeNested=*/false)) {
102         // Note: Run "mlir-opt -debug" to see which pattern is broken.
103         llvm::report_fatal_error("operation finger print changed");
104       }
105     }
106   }
107 
108   void notifyRewriteFailure() {
109     if (!topLevel)
110       return;
111 
112     // Pattern application failure => IR must not have changed.
113     OperationFingerPrint afterFingerPrint(topLevel);
114     if (*topLevelFingerPrint != afterFingerPrint) {
115       // Note: Run "mlir-opt -debug" to see which pattern is broken.
116       llvm::report_fatal_error("pattern returned failure but IR did change");
117     }
118   }
119 
120   void notifyFoldingSuccess() {
121     if (!topLevel)
122       return;
123 
124     // Make sure that the IR still verifies.
125     if (failed(verify(topLevel)))
126       llvm::report_fatal_error("IR failed to verify after folding");
127   }
128 
129 protected:
130   /// Invalidate the finger print of the given op, i.e., remove it from the map.
131   void invalidateFingerPrint(Operation *op) { fingerprints.erase(op); }
132 
133   void notifyBlockErased(Block *block) override {
134     RewriterBase::ForwardingListener::notifyBlockErased(block);
135 
136     // The block structure (number of blocks, types of block arguments, etc.)
137     // is part of the fingerprint of the parent op.
138     // TODO: The parent op fingerprint should also be invalidated when modifying
139     // the block arguments of a block, but we do not have a
140     // `notifyBlockModified` callback yet.
141     invalidateFingerPrint(block->getParentOp());
142   }
143 
144   void notifyOperationInserted(Operation *op,
145                                OpBuilder::InsertPoint previous) override {
146     RewriterBase::ForwardingListener::notifyOperationInserted(op, previous);
147     invalidateFingerPrint(op->getParentOp());
148   }
149 
150   void notifyOperationModified(Operation *op) override {
151     RewriterBase::ForwardingListener::notifyOperationModified(op);
152     invalidateFingerPrint(op);
153   }
154 
155   void notifyOperationErased(Operation *op) override {
156     RewriterBase::ForwardingListener::notifyOperationErased(op);
157     op->walk([this](Operation *op) { invalidateFingerPrint(op); });
158   }
159 
160   /// Operation finger prints to detect invalid pattern API usage. IR is checked
161   /// against these finger prints after pattern application to detect cases
162   /// where IR was modified directly, bypassing the rewriter API.
163   DenseMap<Operation *, OperationFingerPrint> fingerprints;
164 
165   /// Top-level operation of the current greedy rewrite.
166   Operation *topLevel = nullptr;
167 
168   /// Finger print of the top-level operation.
169   std::optional<OperationFingerPrint> topLevelFingerPrint;
170 };
171 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
172 
173 #ifndef NDEBUG
174 static Operation *getDumpRootOp(Operation *op) {
175   // Dump the parent op so that materialized constants are visible. If the op
176   // is a top-level op, dump it directly.
177   if (Operation *parentOp = op->getParentOp())
178     return parentOp;
179   return op;
180 }
181 static void logSuccessfulFolding(Operation *op) {
182   llvm::dbgs() << "// *** IR Dump After Successful Folding ***\n";
183   op->dump();
184   llvm::dbgs() << "\n\n";
185 }
186 #endif // NDEBUG
187 
188 //===----------------------------------------------------------------------===//
189 // Worklist
190 //===----------------------------------------------------------------------===//
191 
/// A LIFO worklist of operations with efficient removal and set semantics.
///
/// This class maintains a vector of operations (`list`) and a mapping of
/// operations to their positions in that vector (`map`), so that operations
/// can be removed efficiently at random. When an operation is removed, its
/// slot in the vector is overwritten with nullptr instead of being erased.
/// Such nullptr placeholders are skipped when pop'ing elements.
class Worklist {
public:
  /// Create an empty worklist.
  Worklist();

  /// Clear the worklist.
  void clear();

  /// Return whether the worklist is empty, i.e., whether it contains no
  /// operation other than nullptr placeholders.
  bool empty() const;

  /// Push an operation to the end of the worklist, unless the operation is
  /// already on the worklist. The operation must not be nullptr.
  void push(Operation *op);

  /// Pop an operation from the end of the worklist. Only allowed on
  /// non-empty worklists.
  Operation *pop();

  /// Remove an operation from the worklist. This is a no-op if the operation
  /// is not on the worklist. The operation must not be nullptr.
  void remove(Operation *op);

  /// Reverse the worklist.
  void reverse();

protected:
  /// The worklist of operations. May contain nullptr placeholders for
  /// operations that were removed.
  std::vector<Operation *> list;

  /// A mapping of operations to positions in `list`.
  DenseMap<Operation *, unsigned> map;
};
229 
230 Worklist::Worklist() { list.reserve(64); }
231 
232 void Worklist::clear() {
233   list.clear();
234   map.clear();
235 }
236 
237 bool Worklist::empty() const {
238   // Skip all nullptr.
239   return !llvm::any_of(list,
240                        [](Operation *op) { return static_cast<bool>(op); });
241 }
242 
243 void Worklist::push(Operation *op) {
244   assert(op && "cannot push nullptr to worklist");
245   // Check to see if the worklist already contains this op.
246   if (!map.insert({op, list.size()}).second)
247     return;
248   list.push_back(op);
249 }
250 
251 Operation *Worklist::pop() {
252   assert(!empty() && "cannot pop from empty worklist");
253   // Skip and remove all trailing nullptr.
254   while (!list.back())
255     list.pop_back();
256   Operation *op = list.back();
257   list.pop_back();
258   map.erase(op);
259   // Cleanup: Remove all trailing nullptr.
260   while (!list.empty() && !list.back())
261     list.pop_back();
262   return op;
263 }
264 
265 void Worklist::remove(Operation *op) {
266   assert(op && "cannot remove nullptr from worklist");
267   auto it = map.find(op);
268   if (it != map.end()) {
269     assert(list[it->second] == op && "malformed worklist data structure");
270     list[it->second] = nullptr;
271     map.erase(it);
272   }
273 }
274 
275 void Worklist::reverse() {
276   std::reverse(list.begin(), list.end());
277   for (size_t i = 0, e = list.size(); i != e; ++i)
278     map[list[i]] = i;
279 }
280 
281 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
282 /// A worklist that pops elements at a random position. This worklist is for
283 /// testing/debugging purposes only. It can be used to ensure that lowering
284 /// pipelines work correctly regardless of the order in which ops are processed
285 /// by the GreedyPatternRewriteDriver.
286 class RandomizedWorklist : public Worklist {
287 public:
288   RandomizedWorklist() : Worklist() {
289     generator.seed(MLIR_GREEDY_REWRITE_RANDOMIZER_SEED);
290   }
291 
292   /// Pop a random non-empty op from the worklist.
293   Operation *pop() {
294     Operation *op = nullptr;
295     do {
296       assert(!list.empty() && "cannot pop from empty worklist");
297       int64_t pos = generator() % list.size();
298       op = list[pos];
299       list.erase(list.begin() + pos);
300       for (int64_t i = pos, e = list.size(); i < e; ++i)
301         map[list[i]] = i;
302       map.erase(op);
303     } while (!op);
304     return op;
305   }
306 
307 private:
308   std::minstd_rand0 generator;
309 };
310 #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
311 
312 //===----------------------------------------------------------------------===//
313 // GreedyPatternRewriteDriver
314 //===----------------------------------------------------------------------===//
315 
/// This is a worklist-driven driver for the PatternApplicator, which
/// repeatedly applies the locally optimal patterns.
///
/// This abstract class manages the worklist and contains helper methods for
/// rewriting ops on the worklist. Derived classes specify how ops are added
/// to the worklist in the beginning.
///
/// The driver is itself a rewriter listener: IR changes made through
/// `rewriter` are reported back to it so that the worklist can be updated.
class GreedyPatternRewriteDriver : public RewriterBase::Listener {
protected:
  explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
                                      const FrozenRewritePatternSet &patterns,
                                      const GreedyRewriteConfig &config);

  /// Add the given operation to the worklist. In strict mode (anything other
  /// than GreedyRewriteStrictness::AnyOp), the operation is only added if it
  /// is contained in `strictModeFilteredOps`.
  void addSingleOpToWorklist(Operation *op);

  /// Add the given operation and its ancestors (up to `config.scope`) to the
  /// worklist. Nothing is added if the operation is not nested in the scope.
  void addToWorklist(Operation *op);

  /// Notify the driver that the specified operation may have been modified
  /// in-place. The operation is added to the worklist.
  void notifyOperationModified(Operation *op) override;

  /// Notify the driver that the specified operation was inserted. Update the
  /// worklist as needed: The operation is enqueued depending on scope and
  /// strict mode.
  void notifyOperationInserted(Operation *op,
                               OpBuilder::InsertPoint previous) override;

  /// Notify the driver that the specified operation was removed. Update the
  /// worklist as needed: The operation is removed from the worklist (and from
  /// `strictModeFilteredOps`), and the defining ops of its operands are
  /// considered for re-enqueueing.
  void notifyOperationErased(Operation *op) override;

  /// Notify the driver that the specified operation was replaced. The
  /// notification is logged and forwarded to `config.listener`, if any.
  void notifyOperationReplaced(Operation *op, ValueRange replacement) override;

  /// Process ops until the worklist is empty or `config.maxNumRewrites` is
  /// reached. Return `true` if any IR was changed.
  bool processWorklist();

  /// The pattern rewriter that is used for making IR modifications and is
  /// passed to rewrite patterns.
  PatternRewriter rewriter;

  /// The worklist for this transformation keeps track of the operations that
  /// need to be (re)visited.
#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
  RandomizedWorklist worklist;
#else
  Worklist worklist;
#endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED

  /// Configuration information for how to simplify.
  const GreedyRewriteConfig config;

  /// The list of ops we are restricting our rewrites to. These include the
  /// supplied set of ops as well as new ops created while rewriting those ops
  /// depending on `strictMode`. This set is not maintained when
  /// `config.strictMode` is GreedyRewriteStrictness::AnyOp.
  llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;

private:
  /// Look over the operands of the given operation for any defining
  /// operations that should be re-added to the worklist. This function should
  /// be called when an operation is modified or removed, as it may trigger
  /// further simplifications.
  void addOperandsToWorklist(Operation *op);

  /// Notify the driver that the given block was inserted. The notification is
  /// forwarded to `config.listener`, if any.
  void notifyBlockInserted(Block *block, Region *previous,
                           Region::iterator previousIt) override;

  /// Notify the driver that the given block is about to be removed. The
  /// notification is forwarded to `config.listener`, if any.
  void notifyBlockErased(Block *block) override;

  /// For debugging only: Notify the driver of a pattern match failure.
  void
  notifyMatchFailure(Location loc,
                     function_ref<void(Diagnostic &)> reasonCallback) override;

#ifndef NDEBUG
  /// A logger used to emit information during the application process.
  llvm::ScopedPrinter logger{llvm::dbgs()};
#endif

  /// The low-level pattern applicator.
  PatternApplicator matcher;

#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
  /// Listener that sits between `rewriter` and this driver and validates
  /// pattern API usage.
  ExpensiveChecks expensiveChecks;
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
};
409 } // namespace
410 
411 GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
412     MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
413     const GreedyRewriteConfig &config)
414     : rewriter(ctx), config(config), matcher(patterns)
415 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
416       // clang-format off
417       , expensiveChecks(
418           /*driver=*/this,
419           /*topLevel=*/config.scope ? config.scope->getParentOp() : nullptr)
420 // clang-format on
421 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
422 {
423   // Apply a simple cost model based solely on pattern benefit.
424   matcher.applyDefaultCostModel();
425 
426   // Set up listener.
427 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
428   // Send IR notifications to the debug handler. This handler will then forward
429   // all notifications to this GreedyPatternRewriteDriver.
430   rewriter.setListener(&expensiveChecks);
431 #else
432   rewriter.setListener(this);
433 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
434 }
435 
/// Drain the worklist. Each popped op is handled in three stages:
///   1. If it is trivially dead, it is erased.
///   2. Otherwise, if `config.fold` is set and the op is not ConstantLike, the
///      driver tries to fold it.
///   3. Otherwise (or if folding did not apply), the op is handed to the
///      pattern applicator.
/// The loop stops when the worklist is empty or when `config.maxNumRewrites`
/// successful pattern applications have been performed (dead-op erasure and
/// folding do not count towards that limit). Returns `true` if any IR was
/// changed.
bool GreedyPatternRewriteDriver::processWorklist() {
#ifndef NDEBUG
  const char *logLineComment =
      "//===-------------------------------------------===//\n";

  /// A utility function to log a process result for the given reason.
  auto logResult = [&](StringRef result, const llvm::Twine &msg = {}) {
    logger.unindent();
    logger.startLine() << "} -> " << result;
    if (!msg.isTriviallyEmpty())
      logger.getOStream() << " : " << msg;
    logger.getOStream() << "\n";
  };
  /// Same as `logResult`, followed by a separator line.
  auto logResultWithLine = [&](StringRef result, const llvm::Twine &msg = {}) {
    logResult(result, msg);
    logger.startLine() << logLineComment;
  };
#endif

  bool changed = false;
  // Number of successful pattern applications so far.
  int64_t numRewrites = 0;
  while (!worklist.empty() &&
         (numRewrites < config.maxNumRewrites ||
          config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) {
    auto *op = worklist.pop();

    LLVM_DEBUG({
      logger.getOStream() << "\n";
      logger.startLine() << logLineComment;
      logger.startLine() << "Processing operation : '" << op->getName() << "'("
                         << op << ") {\n";
      logger.indent();

      // If the operation has no regions, just print it here.
      if (op->getNumRegions() == 0) {
        op->print(
            logger.startLine(),
            OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
        logger.getOStream() << "\n\n";
      }
    });

    // If the operation is trivially dead - remove it.
    if (isOpTriviallyDead(op)) {
      rewriter.eraseOp(op);
      changed = true;

      LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
      continue;
    }

    // Try to fold this op. Do not fold constant ops. That would lead to an
    // infinite folding loop, as every constant op would be folded to an
    // Attribute and then immediately be rematerialized as a constant op, which
    // is then put on the worklist.
    if (config.fold && !op->hasTrait<OpTrait::ConstantLike>()) {
      SmallVector<OpFoldResult> foldResults;
      if (succeeded(op->fold(foldResults))) {
        LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
#ifndef NDEBUG
        // Determine the op to dump before `op` is potentially replaced below.
        Operation *dumpRootOp = getDumpRootOp(op);
#endif // NDEBUG
        if (foldResults.empty()) {
          // Op was modified in-place.
          notifyOperationModified(op);
          changed = true;
          LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
          expensiveChecks.notifyFoldingSuccess();
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
          continue;
        }

        // Op results can be replaced with `foldResults`.
        assert(foldResults.size() == op->getNumResults() &&
               "folder produced incorrect number of results");
        OpBuilder::InsertionGuard g(rewriter);
        rewriter.setInsertionPoint(op);
        SmallVector<Value> replacements;
        bool materializationSucceeded = true;
        for (auto [ofr, resultType] :
             llvm::zip_equal(foldResults, op->getResultTypes())) {
          // Fold results that are already SSA values are used directly.
          if (auto value = ofr.dyn_cast<Value>()) {
            assert(value.getType() == resultType &&
                   "folder produced value of incorrect type");
            replacements.push_back(value);
            continue;
          }
          // Materialize Attributes as SSA values.
          Operation *constOp = op->getDialect()->materializeConstant(
              rewriter, cast<Attribute>(ofr), resultType, op->getLoc());

          if (!constOp) {
            // If materialization fails, cleanup any operations generated for
            // the previous results.
            llvm::SmallDenseSet<Operation *> replacementOps;
            for (Value replacement : replacements) {
              assert(replacement.use_empty() &&
                     "folder reused existing op for one result but constant "
                     "materialization failed for another result");
              replacementOps.insert(replacement.getDefiningOp());
            }
            for (Operation *op : replacementOps) {
              rewriter.eraseOp(op);
            }

            // Give up on folding; the op is handed to the patterns below.
            materializationSucceeded = false;
            break;
          }

          assert(constOp->hasTrait<OpTrait::ConstantLike>() &&
                 "materializeConstant produced op that is not a ConstantLike");
          assert(constOp->getResultTypes()[0] == resultType &&
                 "materializeConstant produced incorrect result type");
          replacements.push_back(constOp->getResult(0));
        }

        if (materializationSucceeded) {
          rewriter.replaceOp(op, replacements);
          changed = true;
          LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
          expensiveChecks.notifyFoldingSuccess();
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
          continue;
        }
      }
    }

    // Try to match one of the patterns. The rewriter is automatically
    // notified of any necessary changes, so there is nothing else to do
    // here.
    // The callbacks below only log and forward pattern begin/end events to
    // `config.listener`; `canApplyCallback` never rejects a pattern.
    auto canApplyCallback = [&](const Pattern &pattern) {
      LLVM_DEBUG({
        logger.getOStream() << "\n";
        logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
                           << op->getName() << " -> (";
        llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
        logger.getOStream() << ")' {\n";
        logger.indent();
      });
      if (config.listener)
        config.listener->notifyPatternBegin(pattern, op);
      return true;
    };
    function_ref<bool(const Pattern &)> canApply = canApplyCallback;
    auto onFailureCallback = [&](const Pattern &pattern) {
      LLVM_DEBUG(logResult("failure", "pattern failed to match"));
      if (config.listener)
        config.listener->notifyPatternEnd(pattern, failure());
    };
    function_ref<void(const Pattern &)> onFailure = onFailureCallback;
    auto onSuccessCallback = [&](const Pattern &pattern) {
      LLVM_DEBUG(logResult("success", "pattern applied successfully"));
      if (config.listener)
        config.listener->notifyPatternEnd(pattern, success());
      return success();
    };
    function_ref<LogicalResult(const Pattern &)> onSuccess = onSuccessCallback;

#ifdef NDEBUG
    // Optimization: PatternApplicator callbacks are not needed when running in
    // optimized mode and without a listener.
    if (!config.listener) {
      canApply = nullptr;
      onFailure = nullptr;
      onSuccess = nullptr;
    }
#endif // NDEBUG

#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
    // Take finger prints right before pattern application and make sure they
    // are dropped again at the end of this loop iteration.
    if (config.scope) {
      expensiveChecks.computeFingerPrints(config.scope->getParentOp());
    }
    auto clearFingerprints =
        llvm::make_scope_exit([&]() { expensiveChecks.clear(); });
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS

    LogicalResult matchResult =
        matcher.matchAndRewrite(op, rewriter, canApply, onFailure, onSuccess);

    if (succeeded(matchResult)) {
      LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
      expensiveChecks.notifyRewriteSuccess();
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
      changed = true;
      // Only successful pattern applications count towards the rewrite limit.
      ++numRewrites;
    } else {
      LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
      expensiveChecks.notifyRewriteFailure();
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
    }
  }

  return changed;
}
634 
635 void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
636   assert(op && "expected valid op");
637   // Gather potential ancestors while looking for a "scope" parent region.
638   SmallVector<Operation *, 8> ancestors;
639   Region *region = nullptr;
640   do {
641     ancestors.push_back(op);
642     region = op->getParentRegion();
643     if (config.scope == region) {
644       // Scope (can be `nullptr`) was reached. Stop traveral and enqueue ops.
645       for (Operation *op : ancestors)
646         addSingleOpToWorklist(op);
647       return;
648     }
649     if (region == nullptr)
650       return;
651   } while ((op = region->getParentOp()));
652 }
653 
654 void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
655   if (config.strictMode == GreedyRewriteStrictness::AnyOp ||
656       strictModeFilteredOps.contains(op))
657     worklist.push(op);
658 }
659 
660 void GreedyPatternRewriteDriver::notifyBlockInserted(
661     Block *block, Region *previous, Region::iterator previousIt) {
662   if (config.listener)
663     config.listener->notifyBlockInserted(block, previous, previousIt);
664 }
665 
666 void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) {
667   if (config.listener)
668     config.listener->notifyBlockErased(block);
669 }
670 
671 void GreedyPatternRewriteDriver::notifyOperationInserted(
672     Operation *op, OpBuilder::InsertPoint previous) {
673   LLVM_DEBUG({
674     logger.startLine() << "** Insert  : '" << op->getName() << "'(" << op
675                        << ")\n";
676   });
677   if (config.listener)
678     config.listener->notifyOperationInserted(op, previous);
679   if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
680     strictModeFilteredOps.insert(op);
681   addToWorklist(op);
682 }
683 
684 void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) {
685   LLVM_DEBUG({
686     logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
687                        << ")\n";
688   });
689   if (config.listener)
690     config.listener->notifyOperationModified(op);
691   addToWorklist(op);
692 }
693 
694 void GreedyPatternRewriteDriver::addOperandsToWorklist(Operation *op) {
695   for (Value operand : op->getOperands()) {
696     // If this operand currently has at most 2 users, add its defining op to the
697     // worklist. Indeed, after the op is deleted, then the operand will have at
698     // most 1 user left. If it has 0 users left, it can be deleted too,
699     // and if it has 1 user left, there may be further canonicalization
700     // opportunities.
701     if (!operand)
702       continue;
703 
704     auto *defOp = operand.getDefiningOp();
705     if (!defOp)
706       continue;
707 
708     Operation *otherUser = nullptr;
709     bool hasMoreThanTwoUses = false;
710     for (auto user : operand.getUsers()) {
711       if (user == op || user == otherUser)
712         continue;
713       if (!otherUser) {
714         otherUser = user;
715         continue;
716       }
717       hasMoreThanTwoUses = true;
718       break;
719     }
720     if (hasMoreThanTwoUses)
721       continue;
722 
723     addToWorklist(defOp);
724   }
725 }
726 
727 void GreedyPatternRewriteDriver::notifyOperationErased(Operation *op) {
728   LLVM_DEBUG({
729     logger.startLine() << "** Erase   : '" << op->getName() << "'(" << op
730                        << ")\n";
731   });
732 
733 #ifndef NDEBUG
734   // Only ops that are within the configured scope are added to the worklist of
735   // the greedy pattern rewriter. Moreover, the parent op of the scope region is
736   // the part of the IR that is taken into account for the "expensive checks".
737   // A greedy pattern rewrite is not allowed to erase the parent op of the scope
738   // region, as that would break the worklist handling and the expensive checks.
739   if (config.scope && config.scope->getParentOp() == op)
740     llvm_unreachable(
741         "scope region must not be erased during greedy pattern rewrite");
742 #endif // NDEBUG
743 
744   if (config.listener)
745     config.listener->notifyOperationErased(op);
746 
747   addOperandsToWorklist(op);
748   worklist.remove(op);
749 
750   if (config.strictMode != GreedyRewriteStrictness::AnyOp)
751     strictModeFilteredOps.erase(op);
752 }
753 
754 void GreedyPatternRewriteDriver::notifyOperationReplaced(
755     Operation *op, ValueRange replacement) {
756   LLVM_DEBUG({
757     logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
758                        << ")\n";
759   });
760   if (config.listener)
761     config.listener->notifyOperationReplaced(op, replacement);
762 }
763 
764 void GreedyPatternRewriteDriver::notifyMatchFailure(
765     Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
766   LLVM_DEBUG({
767     Diagnostic diag(loc, DiagnosticSeverity::Remark);
768     reasonCallback(diag);
769     logger.startLine() << "** Match Failure : " << diag.str() << "\n";
770   });
771   if (config.listener)
772     config.listener->notifyMatchFailure(loc, reasonCallback);
773 }
774 
775 //===----------------------------------------------------------------------===//
776 // RegionPatternRewriteDriver
777 //===----------------------------------------------------------------------===//
778 
779 namespace {
780 /// This driver simplfies all ops in a region.
781 class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver {
782 public:
783   explicit RegionPatternRewriteDriver(MLIRContext *ctx,
784                                       const FrozenRewritePatternSet &patterns,
785                                       const GreedyRewriteConfig &config,
786                                       Region &regions);
787 
788   /// Simplify ops inside `region` and simplify the region itself. Return
789   /// success if the transformation converged.
790   LogicalResult simplify(bool *changed) &&;
791 
792 private:
793   /// The region that is simplified.
794   Region &region;
795 };
796 } // namespace
797 
798 RegionPatternRewriteDriver::RegionPatternRewriteDriver(
799     MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
800     const GreedyRewriteConfig &config, Region &region)
801     : GreedyPatternRewriteDriver(ctx, patterns, config), region(region) {
802   // Populate strict mode ops.
803   if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
804     region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); });
805   }
806 }
807 
808 namespace {
809 class GreedyPatternRewriteIteration
810     : public tracing::ActionImpl<GreedyPatternRewriteIteration> {
811 public:
812   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GreedyPatternRewriteIteration)
813   GreedyPatternRewriteIteration(ArrayRef<IRUnit> units, int64_t iteration)
814       : tracing::ActionImpl<GreedyPatternRewriteIteration>(units),
815         iteration(iteration) {}
816   static constexpr StringLiteral tag = "GreedyPatternRewriteIteration";
817   void print(raw_ostream &os) const override {
818     os << "GreedyPatternRewriteIteration(" << iteration << ")";
819   }
820 
821 private:
822   int64_t iteration = 0;
823 };
824 } // namespace
825 
826 LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
827   bool continueRewrites = false;
828   int64_t iteration = 0;
829   MLIRContext *ctx = rewriter.getContext();
830   do {
831     // Check if the iteration limit was reached.
832     if (++iteration > config.maxIterations &&
833         config.maxIterations != GreedyRewriteConfig::kNoLimit)
834       break;
835 
836     // New iteration: start with an empty worklist.
837     worklist.clear();
838 
839     // `OperationFolder` CSE's constant ops (and may move them into parents
840     // regions to enable more aggressive CSE'ing).
841     OperationFolder folder(ctx, this);
842     auto insertKnownConstant = [&](Operation *op) {
843       // Check for existing constants when populating the worklist. This avoids
844       // accidentally reversing the constant order during processing.
845       Attribute constValue;
846       if (matchPattern(op, m_Constant(&constValue)))
847         if (!folder.insertKnownConstant(op, constValue))
848           return true;
849       return false;
850     };
851 
852     if (!config.useTopDownTraversal) {
853       // Add operations to the worklist in postorder.
854       region.walk([&](Operation *op) {
855         if (!config.cseConstants || !insertKnownConstant(op))
856           addToWorklist(op);
857       });
858     } else {
859       // Add all nested operations to the worklist in preorder.
860       region.walk<WalkOrder::PreOrder>([&](Operation *op) {
861         if (!config.cseConstants || !insertKnownConstant(op)) {
862           addToWorklist(op);
863           return WalkResult::advance();
864         }
865         return WalkResult::skip();
866       });
867 
868       // Reverse the list so our pop-back loop processes them in-order.
869       worklist.reverse();
870     }
871 
872     ctx->executeAction<GreedyPatternRewriteIteration>(
873         [&] {
874           continueRewrites = processWorklist();
875 
876           // After applying patterns, make sure that the CFG of each of the
877           // regions is kept up to date.
878           if (config.enableRegionSimplification !=
879               GreedySimplifyRegionLevel::Disabled) {
880             continueRewrites |= succeeded(simplifyRegions(
881                 rewriter, region,
882                 /*mergeBlocks=*/config.enableRegionSimplification ==
883                     GreedySimplifyRegionLevel::Aggressive));
884           }
885         },
886         {&region}, iteration);
887   } while (continueRewrites);
888 
889   if (changed)
890     *changed = iteration > 1;
891 
892   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
893   return success(!continueRewrites);
894 }
895 
896 LogicalResult
897 mlir::applyPatternsGreedily(Region &region,
898                             const FrozenRewritePatternSet &patterns,
899                             GreedyRewriteConfig config, bool *changed) {
900   // The top-level operation must be known to be isolated from above to
901   // prevent performing canonicalizations on operations defined at or above
902   // the region containing 'op'.
903   assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
904          "patterns can only be applied to operations IsolatedFromAbove");
905 
906   // Set scope if not specified.
907   if (!config.scope)
908     config.scope = &region;
909 
910 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
911   if (failed(verify(config.scope->getParentOp())))
912     llvm::report_fatal_error(
913         "greedy pattern rewriter input IR failed to verify");
914 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
915 
916   // Start the pattern driver.
917   RegionPatternRewriteDriver driver(region.getContext(), patterns, config,
918                                     region);
919   LogicalResult converged = std::move(driver).simplify(changed);
920   LLVM_DEBUG(if (failed(converged)) {
921     llvm::dbgs() << "The pattern rewrite did not converge after scanning "
922                  << config.maxIterations << " times\n";
923   });
924   return converged;
925 }
926 
927 //===----------------------------------------------------------------------===//
928 // MultiOpPatternRewriteDriver
929 //===----------------------------------------------------------------------===//
930 
931 namespace {
932 /// This driver simplfies a list of ops.
933 class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
934 public:
935   explicit MultiOpPatternRewriteDriver(
936       MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
937       const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
938       llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr);
939 
940   /// Simplify `ops`. Return `success` if the transformation converged.
941   LogicalResult simplify(ArrayRef<Operation *> ops, bool *changed = nullptr) &&;
942 
943 private:
944   void notifyOperationErased(Operation *op) override {
945     GreedyPatternRewriteDriver::notifyOperationErased(op);
946     if (survivingOps)
947       survivingOps->erase(op);
948   }
949 
950   /// An optional set of ops that survived the rewrite. This set is populated
951   /// at the beginning of `simplifyLocally` with the inititally provided list
952   /// of ops.
953   llvm::SmallDenseSet<Operation *, 4> *const survivingOps = nullptr;
954 };
955 } // namespace
956 
957 MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
958     MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
959     const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
960     llvm::SmallDenseSet<Operation *, 4> *survivingOps)
961     : GreedyPatternRewriteDriver(ctx, patterns, config),
962       survivingOps(survivingOps) {
963   if (config.strictMode != GreedyRewriteStrictness::AnyOp)
964     strictModeFilteredOps.insert(ops.begin(), ops.end());
965 
966   if (survivingOps) {
967     survivingOps->clear();
968     survivingOps->insert(ops.begin(), ops.end());
969   }
970 }
971 
972 LogicalResult MultiOpPatternRewriteDriver::simplify(ArrayRef<Operation *> ops,
973                                                     bool *changed) && {
974   // Populate the initial worklist.
975   for (Operation *op : ops)
976     addSingleOpToWorklist(op);
977 
978   // Process ops on the worklist.
979   bool result = processWorklist();
980   if (changed)
981     *changed = result;
982 
983   return success(worklist.empty());
984 }
985 
986 /// Find the region that is the closest common ancestor of all given ops.
987 ///
988 /// Note: This function returns `nullptr` if there is a top-level op among the
989 /// given list of ops.
990 static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
991   assert(!ops.empty() && "expected at least one op");
992   // Fast path in case there is only one op.
993   if (ops.size() == 1)
994     return ops.front()->getParentRegion();
995 
996   Region *region = ops.front()->getParentRegion();
997   ops = ops.drop_front();
998   int sz = ops.size();
999   llvm::BitVector remainingOps(sz, true);
1000   while (region) {
1001     int pos = -1;
1002     // Iterate over all remaining ops.
1003     while ((pos = remainingOps.find_first_in(pos + 1, sz)) != -1) {
1004       // Is this op contained in `region`?
1005       if (region->findAncestorOpInRegion(*ops[pos]))
1006         remainingOps.reset(pos);
1007     }
1008     if (remainingOps.none())
1009       break;
1010     region = region->getParentRegion();
1011   }
1012   return region;
1013 }
1014 
1015 LogicalResult mlir::applyOpPatternsGreedily(
1016     ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
1017     GreedyRewriteConfig config, bool *changed, bool *allErased) {
1018   if (ops.empty()) {
1019     if (changed)
1020       *changed = false;
1021     if (allErased)
1022       *allErased = true;
1023     return success();
1024   }
1025 
1026   // Determine scope of rewrite.
1027   if (!config.scope) {
1028     // Compute scope if none was provided. The scope will remain `nullptr` if
1029     // there is a top-level op among `ops`.
1030     config.scope = findCommonAncestor(ops);
1031   } else {
1032     // If a scope was provided, make sure that all ops are in scope.
1033 #ifndef NDEBUG
1034     bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) {
1035       return static_cast<bool>(config.scope->findAncestorOpInRegion(*op));
1036     });
1037     assert(allOpsInScope && "ops must be within the specified scope");
1038 #endif // NDEBUG
1039   }
1040 
1041 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1042   if (config.scope && failed(verify(config.scope->getParentOp())))
1043     llvm::report_fatal_error(
1044         "greedy pattern rewriter input IR failed to verify");
1045 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1046 
1047   // Start the pattern driver.
1048   llvm::SmallDenseSet<Operation *, 4> surviving;
1049   MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
1050                                      config, ops,
1051                                      allErased ? &surviving : nullptr);
1052   LogicalResult converged = std::move(driver).simplify(ops, changed);
1053   if (allErased)
1054     *allErased = surviving.empty();
1055   LLVM_DEBUG(if (failed(converged)) {
1056     llvm::dbgs() << "The pattern rewrite did not converge after "
1057                  << config.maxNumRewrites << " rewrites";
1058   });
1059   return converged;
1060 }
1061