xref: /llvm-project/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp (revision 4f4e2abb1a5ff1225d32410fd02b732d077aa056)
164d52014SChris Lattner //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
264d52014SChris Lattner //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
664d52014SChris Lattner //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
864d52014SChris Lattner //
909dfc571SJacques Pienaar // This file implements mlir::applyPatternsGreedily.
1064d52014SChris Lattner //
1164d52014SChris Lattner //===----------------------------------------------------------------------===//
1264d52014SChris Lattner 
13b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
14e6d90a0dSMatthias Springer 
15e6d90a0dSMatthias Springer #include "mlir/Config/mlir-config.h"
1687e6e490SMehdi Amini #include "mlir/IR/Action.h"
17af371f9fSRiver Riddle #include "mlir/IR/Matchers.h"
1873b86d1bSMatthias Springer #include "mlir/IR/Verifier.h"
19eb623ae8SStephen Neuendorffer #include "mlir/Interfaces/SideEffectInterfaces.h"
20b6eb26fdSRiver Riddle #include "mlir/Rewrite/PatternApplicator.h"
211982afb1SRiver Riddle #include "mlir/Transforms/FoldUtils.h"
22fafb708bSRiver Riddle #include "mlir/Transforms/RegionUtils.h"
23e195e6baSMatthias Springer #include "llvm/ADT/BitVector.h"
2464d52014SChris Lattner #include "llvm/ADT/DenseMap.h"
25774416bdSMatthias Springer #include "llvm/ADT/ScopeExit.h"
265c757087SFeng Liu #include "llvm/Support/CommandLine.h"
275c757087SFeng Liu #include "llvm/Support/Debug.h"
285652ecc3SRiver Riddle #include "llvm/Support/ScopedPrinter.h"
295c757087SFeng Liu #include "llvm/Support/raw_ostream.h"
304e40c832SLei Zhang 
31ce954e1cSMatthias Springer #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
32ce954e1cSMatthias Springer #include <random>
33ce954e1cSMatthias Springer #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
34ce954e1cSMatthias Springer 
3564d52014SChris Lattner using namespace mlir;
3664d52014SChris Lattner 
375652ecc3SRiver Riddle #define DEBUG_TYPE "greedy-rewriter"
385c757087SFeng Liu 
39ca7167d5SMatthias Springer namespace {
40ca7167d5SMatthias Springer 
4104b5274eSUday Bondhugula //===----------------------------------------------------------------------===//
42e6d90a0dSMatthias Springer // Debugging Infrastructure
4304b5274eSUday Bondhugula //===----------------------------------------------------------------------===//
4404b5274eSUday Bondhugula 
455e10a8c4SMatthias Springer #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
46a02a0e80SMatthias Springer /// A helper struct that performs various "expensive checks" to detect broken
47a02a0e80SMatthias Springer /// rewrite patterns use the rewriter API incorrectly. A rewrite pattern is
48a02a0e80SMatthias Springer /// broken if:
49a02a0e80SMatthias Springer /// * IR does not verify after pattern application / folding.
50a02a0e80SMatthias Springer /// * Pattern returns "failure" but the IR has changed.
51a02a0e80SMatthias Springer /// * Pattern returns "success" but the IR has not changed.
52a02a0e80SMatthias Springer ///
53a02a0e80SMatthias Springer /// This struct stores finger prints of ops to determine whether the IR has
54a02a0e80SMatthias Springer /// changed or not.
55a02a0e80SMatthias Springer struct ExpensiveChecks : public RewriterBase::ForwardingListener {
56a02a0e80SMatthias Springer   ExpensiveChecks(RewriterBase::Listener *driver, Operation *topLevel)
57a02a0e80SMatthias Springer       : RewriterBase::ForwardingListener(driver), topLevel(topLevel) {}
58e6d90a0dSMatthias Springer 
59e6d90a0dSMatthias Springer   /// Compute finger prints of the given op and its nested ops.
60e6d90a0dSMatthias Springer   void computeFingerPrints(Operation *topLevel) {
61e6d90a0dSMatthias Springer     this->topLevel = topLevel;
62e6d90a0dSMatthias Springer     this->topLevelFingerPrint.emplace(topLevel);
635fdf8c6fSMatthias Springer     topLevel->walk([&](Operation *op) {
645fdf8c6fSMatthias Springer       fingerprints.try_emplace(op, op, /*includeNested=*/false);
655fdf8c6fSMatthias Springer     });
66e6d90a0dSMatthias Springer   }
67e6d90a0dSMatthias Springer 
68e6d90a0dSMatthias Springer   /// Clear all finger prints.
69e6d90a0dSMatthias Springer   void clear() {
70e6d90a0dSMatthias Springer     topLevel = nullptr;
71e6d90a0dSMatthias Springer     topLevelFingerPrint.reset();
72e6d90a0dSMatthias Springer     fingerprints.clear();
73e6d90a0dSMatthias Springer   }
74e6d90a0dSMatthias Springer 
75e6d90a0dSMatthias Springer   void notifyRewriteSuccess() {
76a02a0e80SMatthias Springer     if (!topLevel)
77a02a0e80SMatthias Springer       return;
78a02a0e80SMatthias Springer 
79a02a0e80SMatthias Springer     // Make sure that the IR still verifies.
80a02a0e80SMatthias Springer     if (failed(verify(topLevel)))
81a02a0e80SMatthias Springer       llvm::report_fatal_error("IR failed to verify after pattern application");
82a02a0e80SMatthias Springer 
83e6d90a0dSMatthias Springer     // Pattern application success => IR must have changed.
84e6d90a0dSMatthias Springer     OperationFingerPrint afterFingerPrint(topLevel);
85e6d90a0dSMatthias Springer     if (*topLevelFingerPrint == afterFingerPrint) {
86e6d90a0dSMatthias Springer       // Note: Run "mlir-opt -debug" to see which pattern is broken.
87e6d90a0dSMatthias Springer       llvm::report_fatal_error(
88e6d90a0dSMatthias Springer           "pattern returned success but IR did not change");
89e6d90a0dSMatthias Springer     }
90e6d90a0dSMatthias Springer     for (const auto &it : fingerprints) {
91e6d90a0dSMatthias Springer       // Skip top-level op, its finger print is never invalidated.
92e6d90a0dSMatthias Springer       if (it.first == topLevel)
93e6d90a0dSMatthias Springer         continue;
94e6d90a0dSMatthias Springer       // Note: Finger print computation may crash when an op was erased
95e6d90a0dSMatthias Springer       // without notifying the rewriter. (Run with ASAN to see where the op was
96e6d90a0dSMatthias Springer       // erased; the op was probably erased directly, bypassing the rewriter
97e6d90a0dSMatthias Springer       // API.) Finger print computation does may not crash if a new op was
98e6d90a0dSMatthias Springer       // created at the same memory location. (But then the finger print should
99e6d90a0dSMatthias Springer       // have changed.)
1005fdf8c6fSMatthias Springer       if (it.second !=
1015fdf8c6fSMatthias Springer           OperationFingerPrint(it.first, /*includeNested=*/false)) {
102e6d90a0dSMatthias Springer         // Note: Run "mlir-opt -debug" to see which pattern is broken.
103e6d90a0dSMatthias Springer         llvm::report_fatal_error("operation finger print changed");
104e6d90a0dSMatthias Springer       }
105e6d90a0dSMatthias Springer     }
106e6d90a0dSMatthias Springer   }
107e6d90a0dSMatthias Springer 
108e6d90a0dSMatthias Springer   void notifyRewriteFailure() {
109a02a0e80SMatthias Springer     if (!topLevel)
110a02a0e80SMatthias Springer       return;
111a02a0e80SMatthias Springer 
112e6d90a0dSMatthias Springer     // Pattern application failure => IR must not have changed.
113e6d90a0dSMatthias Springer     OperationFingerPrint afterFingerPrint(topLevel);
114e6d90a0dSMatthias Springer     if (*topLevelFingerPrint != afterFingerPrint) {
115e6d90a0dSMatthias Springer       // Note: Run "mlir-opt -debug" to see which pattern is broken.
116e6d90a0dSMatthias Springer       llvm::report_fatal_error("pattern returned failure but IR did change");
117e6d90a0dSMatthias Springer     }
118e6d90a0dSMatthias Springer   }
119e6d90a0dSMatthias Springer 
120a02a0e80SMatthias Springer   void notifyFoldingSuccess() {
121a02a0e80SMatthias Springer     if (!topLevel)
122a02a0e80SMatthias Springer       return;
123a02a0e80SMatthias Springer 
124a02a0e80SMatthias Springer     // Make sure that the IR still verifies.
125a02a0e80SMatthias Springer     if (failed(verify(topLevel)))
126a02a0e80SMatthias Springer       llvm::report_fatal_error("IR failed to verify after folding");
127a02a0e80SMatthias Springer   }
128a02a0e80SMatthias Springer 
129e6d90a0dSMatthias Springer protected:
130e6d90a0dSMatthias Springer   /// Invalidate the finger print of the given op, i.e., remove it from the map.
1315fdf8c6fSMatthias Springer   void invalidateFingerPrint(Operation *op) { fingerprints.erase(op); }
1325fdf8c6fSMatthias Springer 
133914e6074SMatthias Springer   void notifyBlockErased(Block *block) override {
134914e6074SMatthias Springer     RewriterBase::ForwardingListener::notifyBlockErased(block);
1355fdf8c6fSMatthias Springer 
1365fdf8c6fSMatthias Springer     // The block structure (number of blocks, types of block arguments, etc.)
1375fdf8c6fSMatthias Springer     // is part of the fingerprint of the parent op.
1385fdf8c6fSMatthias Springer     // TODO: The parent op fingerprint should also be invalidated when modifying
1395fdf8c6fSMatthias Springer     // the block arguments of a block, but we do not have a
1405fdf8c6fSMatthias Springer     // `notifyBlockModified` callback yet.
1415fdf8c6fSMatthias Springer     invalidateFingerPrint(block->getParentOp());
142e6d90a0dSMatthias Springer   }
143e6d90a0dSMatthias Springer 
144c5edef62SMatthias Springer   void notifyOperationInserted(Operation *op,
145c5edef62SMatthias Springer                                OpBuilder::InsertPoint previous) override {
1465cc0f76dSMatthias Springer     RewriterBase::ForwardingListener::notifyOperationInserted(op, previous);
147e6d90a0dSMatthias Springer     invalidateFingerPrint(op->getParentOp());
148e6d90a0dSMatthias Springer   }
149e6d90a0dSMatthias Springer 
150e6d90a0dSMatthias Springer   void notifyOperationModified(Operation *op) override {
151e6d90a0dSMatthias Springer     RewriterBase::ForwardingListener::notifyOperationModified(op);
152e6d90a0dSMatthias Springer     invalidateFingerPrint(op);
153e6d90a0dSMatthias Springer   }
154e6d90a0dSMatthias Springer 
155914e6074SMatthias Springer   void notifyOperationErased(Operation *op) override {
156914e6074SMatthias Springer     RewriterBase::ForwardingListener::notifyOperationErased(op);
157e6d90a0dSMatthias Springer     op->walk([this](Operation *op) { invalidateFingerPrint(op); });
158e6d90a0dSMatthias Springer   }
159e6d90a0dSMatthias Springer 
160e6d90a0dSMatthias Springer   /// Operation finger prints to detect invalid pattern API usage. IR is checked
161e6d90a0dSMatthias Springer   /// against these finger prints after pattern application to detect cases
162e6d90a0dSMatthias Springer   /// where IR was modified directly, bypassing the rewriter API.
163e6d90a0dSMatthias Springer   DenseMap<Operation *, OperationFingerPrint> fingerprints;
164e6d90a0dSMatthias Springer 
165e6d90a0dSMatthias Springer   /// Top-level operation of the current greedy rewrite.
166e6d90a0dSMatthias Springer   Operation *topLevel = nullptr;
167e6d90a0dSMatthias Springer 
168e6d90a0dSMatthias Springer   /// Finger print of the top-level operation.
169e6d90a0dSMatthias Springer   std::optional<OperationFingerPrint> topLevelFingerPrint;
170e6d90a0dSMatthias Springer };
171e6d90a0dSMatthias Springer #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
172e6d90a0dSMatthias Springer 
173dec908a2SMatthias Springer #ifndef NDEBUG
174dec908a2SMatthias Springer static Operation *getDumpRootOp(Operation *op) {
175dec908a2SMatthias Springer   // Dump the parent op so that materialized constants are visible. If the op
176dec908a2SMatthias Springer   // is a top-level op, dump it directly.
177dec908a2SMatthias Springer   if (Operation *parentOp = op->getParentOp())
178dec908a2SMatthias Springer     return parentOp;
179dec908a2SMatthias Springer   return op;
180dec908a2SMatthias Springer }
181dec908a2SMatthias Springer static void logSuccessfulFolding(Operation *op) {
182dec908a2SMatthias Springer   llvm::dbgs() << "// *** IR Dump After Successful Folding ***\n";
183dec908a2SMatthias Springer   op->dump();
184dec908a2SMatthias Springer   llvm::dbgs() << "\n\n";
185dec908a2SMatthias Springer }
186dec908a2SMatthias Springer #endif // NDEBUG
187dec908a2SMatthias Springer 
188e6d90a0dSMatthias Springer //===----------------------------------------------------------------------===//
189ca7167d5SMatthias Springer // Worklist
190ca7167d5SMatthias Springer //===----------------------------------------------------------------------===//
191ca7167d5SMatthias Springer 
192ca7167d5SMatthias Springer /// A LIFO worklist of operations with efficient removal and set semantics.
193ca7167d5SMatthias Springer ///
194ca7167d5SMatthias Springer /// This class maintains a vector of operations and a mapping of operations to
195ca7167d5SMatthias Springer /// positions in the vector, so that operations can be removed efficiently at
196ca7167d5SMatthias Springer /// random. When an operation is removed, it is replaced with nullptr. Such
197ca7167d5SMatthias Springer /// nullptr are skipped when pop'ing elements.
198ca7167d5SMatthias Springer class Worklist {
199ca7167d5SMatthias Springer public:
200ca7167d5SMatthias Springer   Worklist();
201ca7167d5SMatthias Springer 
202ca7167d5SMatthias Springer   /// Clear the worklist.
203ca7167d5SMatthias Springer   void clear();
204ca7167d5SMatthias Springer 
205ca7167d5SMatthias Springer   /// Return whether the worklist is empty.
206ca7167d5SMatthias Springer   bool empty() const;
207ca7167d5SMatthias Springer 
208ca7167d5SMatthias Springer   /// Push an operation to the end of the worklist, unless the operation is
209ca7167d5SMatthias Springer   /// already on the worklist.
210ca7167d5SMatthias Springer   void push(Operation *op);
211ca7167d5SMatthias Springer 
212ca7167d5SMatthias Springer   /// Pop the an operation from the end of the worklist. Only allowed on
213ca7167d5SMatthias Springer   /// non-empty worklists.
214ca7167d5SMatthias Springer   Operation *pop();
215ca7167d5SMatthias Springer 
216ca7167d5SMatthias Springer   /// Remove an operation from the worklist.
217ca7167d5SMatthias Springer   void remove(Operation *op);
218ca7167d5SMatthias Springer 
219ca7167d5SMatthias Springer   /// Reverse the worklist.
220ca7167d5SMatthias Springer   void reverse();
221ca7167d5SMatthias Springer 
222ce954e1cSMatthias Springer protected:
223ca7167d5SMatthias Springer   /// The worklist of operations.
224ca7167d5SMatthias Springer   std::vector<Operation *> list;
225ca7167d5SMatthias Springer 
226ca7167d5SMatthias Springer   /// A mapping of operations to positions in `list`.
227ca7167d5SMatthias Springer   DenseMap<Operation *, unsigned> map;
228ca7167d5SMatthias Springer };
229ca7167d5SMatthias Springer 
230ca7167d5SMatthias Springer Worklist::Worklist() { list.reserve(64); }
231ca7167d5SMatthias Springer 
232ca7167d5SMatthias Springer void Worklist::clear() {
233ca7167d5SMatthias Springer   list.clear();
234ca7167d5SMatthias Springer   map.clear();
235ca7167d5SMatthias Springer }
236ca7167d5SMatthias Springer 
237ca7167d5SMatthias Springer bool Worklist::empty() const {
238ca7167d5SMatthias Springer   // Skip all nullptr.
239ca7167d5SMatthias Springer   return !llvm::any_of(list,
240ca7167d5SMatthias Springer                        [](Operation *op) { return static_cast<bool>(op); });
241ca7167d5SMatthias Springer }
242ca7167d5SMatthias Springer 
243ca7167d5SMatthias Springer void Worklist::push(Operation *op) {
244ca7167d5SMatthias Springer   assert(op && "cannot push nullptr to worklist");
245ca7167d5SMatthias Springer   // Check to see if the worklist already contains this op.
24660c5c4ccSMehdi Amini   if (!map.insert({op, list.size()}).second)
247ca7167d5SMatthias Springer     return;
248ca7167d5SMatthias Springer   list.push_back(op);
249ca7167d5SMatthias Springer }
250ca7167d5SMatthias Springer 
251ca7167d5SMatthias Springer Operation *Worklist::pop() {
252ca7167d5SMatthias Springer   assert(!empty() && "cannot pop from empty worklist");
253ca7167d5SMatthias Springer   // Skip and remove all trailing nullptr.
254ca7167d5SMatthias Springer   while (!list.back())
255ca7167d5SMatthias Springer     list.pop_back();
256ca7167d5SMatthias Springer   Operation *op = list.back();
257ca7167d5SMatthias Springer   list.pop_back();
258ca7167d5SMatthias Springer   map.erase(op);
259ca7167d5SMatthias Springer   // Cleanup: Remove all trailing nullptr.
260ca7167d5SMatthias Springer   while (!list.empty() && !list.back())
261ca7167d5SMatthias Springer     list.pop_back();
262ca7167d5SMatthias Springer   return op;
263ca7167d5SMatthias Springer }
264ca7167d5SMatthias Springer 
265ca7167d5SMatthias Springer void Worklist::remove(Operation *op) {
266ca7167d5SMatthias Springer   assert(op && "cannot remove nullptr from worklist");
267ca7167d5SMatthias Springer   auto it = map.find(op);
268ca7167d5SMatthias Springer   if (it != map.end()) {
269ca7167d5SMatthias Springer     assert(list[it->second] == op && "malformed worklist data structure");
270ca7167d5SMatthias Springer     list[it->second] = nullptr;
271ca7167d5SMatthias Springer     map.erase(it);
272ca7167d5SMatthias Springer   }
273ca7167d5SMatthias Springer }
274ca7167d5SMatthias Springer 
275ca7167d5SMatthias Springer void Worklist::reverse() {
276ca7167d5SMatthias Springer   std::reverse(list.begin(), list.end());
277ca7167d5SMatthias Springer   for (size_t i = 0, e = list.size(); i != e; ++i)
278ca7167d5SMatthias Springer     map[list[i]] = i;
279ca7167d5SMatthias Springer }
280ca7167d5SMatthias Springer 
281ce954e1cSMatthias Springer #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
282ce954e1cSMatthias Springer /// A worklist that pops elements at a random position. This worklist is for
283ce954e1cSMatthias Springer /// testing/debugging purposes only. It can be used to ensure that lowering
284ce954e1cSMatthias Springer /// pipelines work correctly regardless of the order in which ops are processed
285ce954e1cSMatthias Springer /// by the GreedyPatternRewriteDriver.
286ce954e1cSMatthias Springer class RandomizedWorklist : public Worklist {
287ce954e1cSMatthias Springer public:
288ce954e1cSMatthias Springer   RandomizedWorklist() : Worklist() {
289ce954e1cSMatthias Springer     generator.seed(MLIR_GREEDY_REWRITE_RANDOMIZER_SEED);
290ce954e1cSMatthias Springer   }
291ce954e1cSMatthias Springer 
292ce954e1cSMatthias Springer   /// Pop a random non-empty op from the worklist.
293ce954e1cSMatthias Springer   Operation *pop() {
294ce954e1cSMatthias Springer     Operation *op = nullptr;
295ce954e1cSMatthias Springer     do {
296ce954e1cSMatthias Springer       assert(!list.empty() && "cannot pop from empty worklist");
297ce954e1cSMatthias Springer       int64_t pos = generator() % list.size();
298ce954e1cSMatthias Springer       op = list[pos];
299ce954e1cSMatthias Springer       list.erase(list.begin() + pos);
300ce954e1cSMatthias Springer       for (int64_t i = pos, e = list.size(); i < e; ++i)
301ce954e1cSMatthias Springer         map[list[i]] = i;
302ce954e1cSMatthias Springer       map.erase(op);
303ce954e1cSMatthias Springer     } while (!op);
304ce954e1cSMatthias Springer     return op;
305ce954e1cSMatthias Springer   }
306ce954e1cSMatthias Springer 
307ce954e1cSMatthias Springer private:
308ce954e1cSMatthias Springer   std::minstd_rand0 generator;
309ce954e1cSMatthias Springer };
310ce954e1cSMatthias Springer #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
311ce954e1cSMatthias Springer 
312ca7167d5SMatthias Springer //===----------------------------------------------------------------------===//
313e6d90a0dSMatthias Springer // GreedyPatternRewriteDriver
314e6d90a0dSMatthias Springer //===----------------------------------------------------------------------===//
315e6d90a0dSMatthias Springer 
31664d52014SChris Lattner /// This is a worklist-driven driver for the PatternMatcher, which repeatedly
3179d5c63f6SMatthias Springer /// applies the locally optimal patterns.
3189d5c63f6SMatthias Springer ///
3199d5c63f6SMatthias Springer /// This abstract class manages the worklist and contains helper methods for
3209d5c63f6SMatthias Springer /// rewriting ops on the worklist. Derived classes specify how ops are added
3219d5c63f6SMatthias Springer /// to the worklist in the beginning.
3226b3e0002SMatthias Springer class GreedyPatternRewriteDriver : public RewriterBase::Listener {
3239d5c63f6SMatthias Springer protected:
3242566a72aSRiver Riddle   explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
325648f34a2SChris Lattner                                       const FrozenRewritePatternSet &patterns,
326977cddb9SMatthias Springer                                       const GreedyRewriteConfig &config);
3273e98fbf4SRiver Riddle 
3289d5c63f6SMatthias Springer   /// Add the given operation to the worklist.
3299d5c63f6SMatthias Springer   void addSingleOpToWorklist(Operation *op);
33064d52014SChris Lattner 
331e195e6baSMatthias Springer   /// Add the given operation and its ancestors to the worklist.
332e195e6baSMatthias Springer   void addToWorklist(Operation *op);
3335c4f1fddSRiver Riddle 
3349d5c63f6SMatthias Springer   /// Notify the driver that the specified operation may have been modified
3359d5c63f6SMatthias Springer   /// in-place. The operation is added to the worklist.
336bafc4dfcSMatthias Springer   void notifyOperationModified(Operation *op) override;
3372fea658aSMatthias Springer 
3389d5c63f6SMatthias Springer   /// Notify the driver that the specified operation was inserted. Update the
3399d5c63f6SMatthias Springer   /// worklist as needed: The operation is enqueued depending on scope and
3409d5c63f6SMatthias Springer   /// strict mode.
3416b3e0002SMatthias Springer   void notifyOperationInserted(Operation *op,
3426b3e0002SMatthias Springer                                OpBuilder::InsertPoint previous) override;
34364d52014SChris Lattner 
3449d5c63f6SMatthias Springer   /// Notify the driver that the specified operation was removed. Update the
3459d5c63f6SMatthias Springer   /// worklist as needed: The operation and its children are removed from the
3469d5c63f6SMatthias Springer   /// worklist.
347914e6074SMatthias Springer   void notifyOperationErased(Operation *op) override;
34864d52014SChris Lattner 
3499d5c63f6SMatthias Springer   /// Notify the driver that the specified operation was replaced. Update the
3509d5c63f6SMatthias Springer   /// worklist as needed: New users are added enqueued.
351c6532830SMatthias Springer   void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
352085b687fSChris Lattner 
3539d5c63f6SMatthias Springer   /// Process ops until the worklist is empty or `config.maxNumRewrites` is
3549d5c63f6SMatthias Springer   /// reached. Return `true` if any IR was changed.
3559d5c63f6SMatthias Springer   bool processWorklist();
3564bd9f936SChris Lattner 
3576b3e0002SMatthias Springer   /// The pattern rewriter that is used for making IR modifications and is
3586b3e0002SMatthias Springer   /// passed to rewrite patterns.
3596b3e0002SMatthias Springer   PatternRewriter rewriter;
3606b3e0002SMatthias Springer 
3614bd9f936SChris Lattner   /// The worklist for this transformation keeps track of the operations that
362ca7167d5SMatthias Springer   /// need to be (re)visited.
363ce954e1cSMatthias Springer #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
364ce954e1cSMatthias Springer   RandomizedWorklist worklist;
365ce954e1cSMatthias Springer #else
366ca7167d5SMatthias Springer   Worklist worklist;
367ce954e1cSMatthias Springer #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
36860a29837SRiver Riddle 
36964716b2cSChris Lattner   /// Configuration information for how to simplify.
37067760d7eSMatthias Springer   const GreedyRewriteConfig config;
3715652ecc3SRiver Riddle 
3726bdecbcbSMatthias Springer   /// The list of ops we are restricting our rewrites to. These include the
3736bdecbcbSMatthias Springer   /// supplied set of ops as well as new ops created while rewriting those ops
3746bdecbcbSMatthias Springer   /// depending on `strictMode`. This set is not maintained when
3756bdecbcbSMatthias Springer   /// `config.strictMode` is GreedyRewriteStrictness::AnyOp.
3766bdecbcbSMatthias Springer   llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
3776bdecbcbSMatthias Springer 
378e195e6baSMatthias Springer private:
3799d5c63f6SMatthias Springer   /// Look over the provided operands for any defining operations that should
3809d5c63f6SMatthias Springer   /// be re-added to the worklist. This function should be called when an
3819d5c63f6SMatthias Springer   /// operation is modified or removed, as it may trigger further
3829d5c63f6SMatthias Springer   /// simplifications.
383ddc98929Smlevesquedion   void addOperandsToWorklist(Operation *op);
3849d5c63f6SMatthias Springer 
3853ed98cb3SMatthias Springer   /// Notify the driver that the given block was inserted.
3863ed98cb3SMatthias Springer   void notifyBlockInserted(Block *block, Region *previous,
3873ed98cb3SMatthias Springer                            Region::iterator previousIt) override;
388279c1d2bSMatthias Springer 
38962bf7710SMatthias Springer   /// Notify the driver that the given block is about to be removed.
390914e6074SMatthias Springer   void notifyBlockErased(Block *block) override;
39162bf7710SMatthias Springer 
3929d5c63f6SMatthias Springer   /// For debugging only: Notify the driver of a pattern match failure.
3939a028afdSMatthias Springer   void
3949d5c63f6SMatthias Springer   notifyMatchFailure(Location loc,
3959d5c63f6SMatthias Springer                      function_ref<void(Diagnostic &)> reasonCallback) override;
3969d5c63f6SMatthias Springer 
3975652ecc3SRiver Riddle #ifndef NDEBUG
3985652ecc3SRiver Riddle   /// A logger used to emit information during the application process.
3995652ecc3SRiver Riddle   llvm::ScopedPrinter logger{llvm::dbgs()};
4005652ecc3SRiver Riddle #endif
4019d5c63f6SMatthias Springer 
4029d5c63f6SMatthias Springer   /// The low-level pattern applicator.
4039d5c63f6SMatthias Springer   PatternApplicator matcher;
404e6d90a0dSMatthias Springer 
4055e10a8c4SMatthias Springer #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
406a02a0e80SMatthias Springer   ExpensiveChecks expensiveChecks;
407e6d90a0dSMatthias Springer #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
40864d52014SChris Lattner };
409be0a7e9fSMehdi Amini } // namespace
41064d52014SChris Lattner 
411b7144ab7SRiver Riddle GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
412b7144ab7SRiver Riddle     MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
413977cddb9SMatthias Springer     const GreedyRewriteConfig &config)
4146b3e0002SMatthias Springer     : rewriter(ctx), config(config), matcher(patterns)
4155e10a8c4SMatthias Springer #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
416e6d90a0dSMatthias Springer       // clang-format off
417a02a0e80SMatthias Springer       , expensiveChecks(
418a02a0e80SMatthias Springer           /*driver=*/this,
419a02a0e80SMatthias Springer           /*topLevel=*/config.scope ? config.scope->getParentOp() : nullptr)
420e6d90a0dSMatthias Springer // clang-format on
421e6d90a0dSMatthias Springer #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
422e6d90a0dSMatthias Springer {
423b7144ab7SRiver Riddle   // Apply a simple cost model based solely on pattern benefit.
424b7144ab7SRiver Riddle   matcher.applyDefaultCostModel();
425c6532830SMatthias Springer 
426c6532830SMatthias Springer   // Set up listener.
4275e10a8c4SMatthias Springer #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
428e6d90a0dSMatthias Springer   // Send IR notifications to the debug handler. This handler will then forward
429e6d90a0dSMatthias Springer   // all notifications to this GreedyPatternRewriteDriver.
4306b3e0002SMatthias Springer   rewriter.setListener(&expensiveChecks);
431e6d90a0dSMatthias Springer #else
4326b3e0002SMatthias Springer   rewriter.setListener(this);
433e6d90a0dSMatthias Springer #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
434b7144ab7SRiver Riddle }
435b7144ab7SRiver Riddle 
4369d5c63f6SMatthias Springer bool GreedyPatternRewriteDriver::processWorklist() {
4375652ecc3SRiver Riddle #ifndef NDEBUG
4385652ecc3SRiver Riddle   const char *logLineComment =
4395652ecc3SRiver Riddle       "//===-------------------------------------------===//\n";
4405652ecc3SRiver Riddle 
4415652ecc3SRiver Riddle   /// A utility function to log a process result for the given reason.
4425652ecc3SRiver Riddle   auto logResult = [&](StringRef result, const llvm::Twine &msg = {}) {
4435652ecc3SRiver Riddle     logger.unindent();
4445652ecc3SRiver Riddle     logger.startLine() << "} -> " << result;
4455652ecc3SRiver Riddle     if (!msg.isTriviallyEmpty())
4465652ecc3SRiver Riddle       logger.getOStream() << " : " << msg;
4475652ecc3SRiver Riddle     logger.getOStream() << "\n";
4485652ecc3SRiver Riddle   };
4495652ecc3SRiver Riddle   auto logResultWithLine = [&](StringRef result, const llvm::Twine &msg = {}) {
4505652ecc3SRiver Riddle     logResult(result, msg);
4515652ecc3SRiver Riddle     logger.startLine() << logLineComment;
4525652ecc3SRiver Riddle   };
4535652ecc3SRiver Riddle #endif
4545652ecc3SRiver Riddle 
4555c757087SFeng Liu   bool changed = false;
456391cb541SMatthias Springer   int64_t numRewrites = 0;
4570ff3cf0cSMatthias Springer   while (!worklist.empty() &&
4580ff3cf0cSMatthias Springer          (numRewrites < config.maxNumRewrites ||
4590ff3cf0cSMatthias Springer           config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) {
460ca7167d5SMatthias Springer     auto *op = worklist.pop();
46164d52014SChris Lattner 
4625652ecc3SRiver Riddle     LLVM_DEBUG({
4635652ecc3SRiver Riddle       logger.getOStream() << "\n";
4645652ecc3SRiver Riddle       logger.startLine() << logLineComment;
4659d5c63f6SMatthias Springer       logger.startLine() << "Processing operation : '" << op->getName() << "'("
4669d5c63f6SMatthias Springer                          << op << ") {\n";
4675652ecc3SRiver Riddle       logger.indent();
4685652ecc3SRiver Riddle 
4695652ecc3SRiver Riddle       // If the operation has no regions, just print it here.
4705652ecc3SRiver Riddle       if (op->getNumRegions() == 0) {
4715652ecc3SRiver Riddle         op->print(
4725652ecc3SRiver Riddle             logger.startLine(),
4735652ecc3SRiver Riddle             OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
4745652ecc3SRiver Riddle         logger.getOStream() << "\n\n";
4755652ecc3SRiver Riddle       }
4765652ecc3SRiver Riddle     });
4775652ecc3SRiver Riddle 
4780ddba0bdSRiver Riddle     // If the operation is trivially dead - remove it.
4790ddba0bdSRiver Riddle     if (isOpTriviallyDead(op)) {
4806b3e0002SMatthias Springer       rewriter.eraseOp(op);
481f875e55bSUday Bondhugula       changed = true;
4825652ecc3SRiver Riddle 
4835652ecc3SRiver Riddle       LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
48464d52014SChris Lattner       continue;
48564d52014SChris Lattner     }
48664d52014SChris Lattner 
487bb6d5c22SMatthias Springer     // Try to fold this op. Do not fold constant ops. That would lead to an
488bb6d5c22SMatthias Springer     // infinite folding loop, as every constant op would be folded to an
489bb6d5c22SMatthias Springer     // Attribute and then immediately be rematerialized as a constant op, which
490bb6d5c22SMatthias Springer     // is then put on the worklist.
49109dfc571SJacques Pienaar     if (config.fold && !op->hasTrait<OpTrait::ConstantLike>()) {
492bb6d5c22SMatthias Springer       SmallVector<OpFoldResult> foldResults;
493bb6d5c22SMatthias Springer       if (succeeded(op->fold(foldResults))) {
4945652ecc3SRiver Riddle         LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
495dec908a2SMatthias Springer #ifndef NDEBUG
496dec908a2SMatthias Springer         Operation *dumpRootOp = getDumpRootOp(op);
497dec908a2SMatthias Springer #endif // NDEBUG
498bb6d5c22SMatthias Springer         if (foldResults.empty()) {
499bb6d5c22SMatthias Springer           // Op was modified in-place.
500bb6d5c22SMatthias Springer           notifyOperationModified(op);
501eb42868fSBilly Zhu           changed = true;
502dec908a2SMatthias Springer           LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
50373b86d1bSMatthias Springer #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
504a02a0e80SMatthias Springer           expensiveChecks.notifyFoldingSuccess();
50573b86d1bSMatthias Springer #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
506934b6d12SChris Lattner           continue;
50764d52014SChris Lattner         }
50864d52014SChris Lattner 
509bb6d5c22SMatthias Springer         // Op results can be replaced with `foldResults`.
510bb6d5c22SMatthias Springer         assert(foldResults.size() == op->getNumResults() &&
511bb6d5c22SMatthias Springer                "folder produced incorrect number of results");
5126b3e0002SMatthias Springer         OpBuilder::InsertionGuard g(rewriter);
5136b3e0002SMatthias Springer         rewriter.setInsertionPoint(op);
514bb6d5c22SMatthias Springer         SmallVector<Value> replacements;
515eb42868fSBilly Zhu         bool materializationSucceeded = true;
516bb6d5c22SMatthias Springer         for (auto [ofr, resultType] :
517bb6d5c22SMatthias Springer              llvm::zip_equal(foldResults, op->getResultTypes())) {
518bb6d5c22SMatthias Springer           if (auto value = ofr.dyn_cast<Value>()) {
519bb6d5c22SMatthias Springer             assert(value.getType() == resultType &&
520bb6d5c22SMatthias Springer                    "folder produced value of incorrect type");
521bb6d5c22SMatthias Springer             replacements.push_back(value);
522bb6d5c22SMatthias Springer             continue;
523bb6d5c22SMatthias Springer           }
524bb6d5c22SMatthias Springer           // Materialize Attributes as SSA values.
525bb6d5c22SMatthias Springer           Operation *constOp = op->getDialect()->materializeConstant(
526*4f4e2abbSKazu Hirata               rewriter, cast<Attribute>(ofr), resultType, op->getLoc());
527eb42868fSBilly Zhu 
528eb42868fSBilly Zhu           if (!constOp) {
529eb42868fSBilly Zhu             // If materialization fails, cleanup any operations generated for
530eb42868fSBilly Zhu             // the previous results.
531eb42868fSBilly Zhu             llvm::SmallDenseSet<Operation *> replacementOps;
532eb42868fSBilly Zhu             for (Value replacement : replacements) {
533eb42868fSBilly Zhu               assert(replacement.use_empty() &&
534eb42868fSBilly Zhu                      "folder reused existing op for one result but constant "
535eb42868fSBilly Zhu                      "materialization failed for another result");
536eb42868fSBilly Zhu               replacementOps.insert(replacement.getDefiningOp());
537eb42868fSBilly Zhu             }
538eb42868fSBilly Zhu             for (Operation *op : replacementOps) {
5396b3e0002SMatthias Springer               rewriter.eraseOp(op);
540eb42868fSBilly Zhu             }
541eb42868fSBilly Zhu 
542eb42868fSBilly Zhu             materializationSucceeded = false;
543eb42868fSBilly Zhu             break;
544eb42868fSBilly Zhu           }
545eb42868fSBilly Zhu 
546bb6d5c22SMatthias Springer           assert(constOp->hasTrait<OpTrait::ConstantLike>() &&
547bb6d5c22SMatthias Springer                  "materializeConstant produced op that is not a ConstantLike");
548bb6d5c22SMatthias Springer           assert(constOp->getResultTypes()[0] == resultType &&
549bb6d5c22SMatthias Springer                  "materializeConstant produced incorrect result type");
550bb6d5c22SMatthias Springer           replacements.push_back(constOp->getResult(0));
551bb6d5c22SMatthias Springer         }
552eb42868fSBilly Zhu 
553eb42868fSBilly Zhu         if (materializationSucceeded) {
5546b3e0002SMatthias Springer           rewriter.replaceOp(op, replacements);
555eb42868fSBilly Zhu           changed = true;
556dec908a2SMatthias Springer           LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
557bb6d5c22SMatthias Springer #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
558a02a0e80SMatthias Springer           expensiveChecks.notifyFoldingSuccess();
559bb6d5c22SMatthias Springer #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
560bb6d5c22SMatthias Springer           continue;
561bb6d5c22SMatthias Springer         }
562bb6d5c22SMatthias Springer       }
563eb42868fSBilly Zhu     }
564bb6d5c22SMatthias Springer 
56532052c84SRiver Riddle     // Try to match one of the patterns. The rewriter is automatically
566648f34a2SChris Lattner     // notified of any necessary changes, so there is nothing else to do
567648f34a2SChris Lattner     // here.
5689b6bd709SMatthias Springer     auto canApplyCallback = [&](const Pattern &pattern) {
5695652ecc3SRiver Riddle       LLVM_DEBUG({
5705652ecc3SRiver Riddle         logger.getOStream() << "\n";
5715652ecc3SRiver Riddle         logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
5725652ecc3SRiver Riddle                            << op->getName() << " -> (";
5735652ecc3SRiver Riddle         llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
5745652ecc3SRiver Riddle         logger.getOStream() << ")' {\n";
5755652ecc3SRiver Riddle         logger.indent();
5765652ecc3SRiver Riddle       });
5779b6bd709SMatthias Springer       if (config.listener)
5789b6bd709SMatthias Springer         config.listener->notifyPatternBegin(pattern, op);
5795652ecc3SRiver Riddle       return true;
5805652ecc3SRiver Riddle     };
5819b6bd709SMatthias Springer     function_ref<bool(const Pattern &)> canApply = canApplyCallback;
5829b6bd709SMatthias Springer     auto onFailureCallback = [&](const Pattern &pattern) {
5835652ecc3SRiver Riddle       LLVM_DEBUG(logResult("failure", "pattern failed to match"));
5849b6bd709SMatthias Springer       if (config.listener)
5859b6bd709SMatthias Springer         config.listener->notifyPatternEnd(pattern, failure());
5865652ecc3SRiver Riddle     };
5879b6bd709SMatthias Springer     function_ref<void(const Pattern &)> onFailure = onFailureCallback;
5889b6bd709SMatthias Springer     auto onSuccessCallback = [&](const Pattern &pattern) {
5895652ecc3SRiver Riddle       LLVM_DEBUG(logResult("success", "pattern applied successfully"));
5909b6bd709SMatthias Springer       if (config.listener)
5919b6bd709SMatthias Springer         config.listener->notifyPatternEnd(pattern, success());
5925652ecc3SRiver Riddle       return success();
5935652ecc3SRiver Riddle     };
5949b6bd709SMatthias Springer     function_ref<LogicalResult(const Pattern &)> onSuccess = onSuccessCallback;
5959b6bd709SMatthias Springer 
5969b6bd709SMatthias Springer #ifdef NDEBUG
5979b6bd709SMatthias Springer     // Optimization: PatternApplicator callbacks are not needed when running in
5989b6bd709SMatthias Springer     // optimized mode and without a listener.
5999b6bd709SMatthias Springer     if (!config.listener) {
6009b6bd709SMatthias Springer       canApply = nullptr;
6019b6bd709SMatthias Springer       onFailure = nullptr;
6029b6bd709SMatthias Springer       onSuccess = nullptr;
6039b6bd709SMatthias Springer     }
6049b6bd709SMatthias Springer #endif // NDEBUG
6055652ecc3SRiver Riddle 
6065e10a8c4SMatthias Springer #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
60773b86d1bSMatthias Springer     if (config.scope) {
608a02a0e80SMatthias Springer       expensiveChecks.computeFingerPrints(config.scope->getParentOp());
60973b86d1bSMatthias Springer     }
610e6d90a0dSMatthias Springer     auto clearFingerprints =
611a02a0e80SMatthias Springer         llvm::make_scope_exit([&]() { expensiveChecks.clear(); });
612e6d90a0dSMatthias Springer #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
613e6d90a0dSMatthias Springer 
6145652ecc3SRiver Riddle     LogicalResult matchResult =
6156b3e0002SMatthias Springer         matcher.matchAndRewrite(op, rewriter, canApply, onFailure, onSuccess);
6160ff3cf0cSMatthias Springer 
617391cb541SMatthias Springer     if (succeeded(matchResult)) {
618aa051a09SMatthias Springer       LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
6195e10a8c4SMatthias Springer #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
620a02a0e80SMatthias Springer       expensiveChecks.notifyRewriteSuccess();
621e6d90a0dSMatthias Springer #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
622391cb541SMatthias Springer       changed = true;
6230ff3cf0cSMatthias Springer       ++numRewrites;
624aa051a09SMatthias Springer     } else {
625aa051a09SMatthias Springer       LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
6265e10a8c4SMatthias Springer #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
627a02a0e80SMatthias Springer       expensiveChecks.notifyRewriteFailure();
628e6d90a0dSMatthias Springer #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
629391cb541SMatthias Springer     }
63064d52014SChris Lattner   }
631a32f0dcbSRiver Riddle 
6329d5c63f6SMatthias Springer   return changed;
63364d52014SChris Lattner }
63464d52014SChris Lattner 
635b7144ab7SRiver Riddle void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
636e6d90a0dSMatthias Springer   assert(op && "expected valid op");
637ed9194beSMatthias Springer   // Gather potential ancestors while looking for a "scope" parent region.
638ed9194beSMatthias Springer   SmallVector<Operation *, 8> ancestors;
639724a0e2cSMatthias Springer   Region *region = nullptr;
640724a0e2cSMatthias Springer   do {
641ed9194beSMatthias Springer     ancestors.push_back(op);
642724a0e2cSMatthias Springer     region = op->getParentRegion();
643977cddb9SMatthias Springer     if (config.scope == region) {
644724a0e2cSMatthias Springer       // Scope (can be `nullptr`) was reached. Stop traveral and enqueue ops.
645ed9194beSMatthias Springer       for (Operation *op : ancestors)
646ed9194beSMatthias Springer         addSingleOpToWorklist(op);
647724a0e2cSMatthias Springer       return;
648ed9194beSMatthias Springer     }
649724a0e2cSMatthias Springer     if (region == nullptr)
650724a0e2cSMatthias Springer       return;
651724a0e2cSMatthias Springer   } while ((op = region->getParentOp()));
652ed9194beSMatthias Springer }
653ed9194beSMatthias Springer 
654ed9194beSMatthias Springer void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
6556bdecbcbSMatthias Springer   if (config.strictMode == GreedyRewriteStrictness::AnyOp ||
656ca7167d5SMatthias Springer       strictModeFilteredOps.contains(op))
657ca7167d5SMatthias Springer     worklist.push(op);
658b7144ab7SRiver Riddle }
659b7144ab7SRiver Riddle 
6603ed98cb3SMatthias Springer void GreedyPatternRewriteDriver::notifyBlockInserted(
6613ed98cb3SMatthias Springer     Block *block, Region *previous, Region::iterator previousIt) {
662279c1d2bSMatthias Springer   if (config.listener)
6633ed98cb3SMatthias Springer     config.listener->notifyBlockInserted(block, previous, previousIt);
664279c1d2bSMatthias Springer }
665279c1d2bSMatthias Springer 
666914e6074SMatthias Springer void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) {
66762bf7710SMatthias Springer   if (config.listener)
668914e6074SMatthias Springer     config.listener->notifyBlockErased(block);
66962bf7710SMatthias Springer }
67062bf7710SMatthias Springer 
6716b3e0002SMatthias Springer void GreedyPatternRewriteDriver::notifyOperationInserted(
6726b3e0002SMatthias Springer     Operation *op, OpBuilder::InsertPoint previous) {
6735652ecc3SRiver Riddle   LLVM_DEBUG({
6745652ecc3SRiver Riddle     logger.startLine() << "** Insert  : '" << op->getName() << "'(" << op
6755652ecc3SRiver Riddle                        << ")\n";
6765652ecc3SRiver Riddle   });
677279c1d2bSMatthias Springer   if (config.listener)
6785cc0f76dSMatthias Springer     config.listener->notifyOperationInserted(op, previous);
6796bdecbcbSMatthias Springer   if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
6806bdecbcbSMatthias Springer     strictModeFilteredOps.insert(op);
681b7144ab7SRiver Riddle   addToWorklist(op);
682b7144ab7SRiver Riddle }
683b7144ab7SRiver Riddle 
684bafc4dfcSMatthias Springer void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) {
6852fea658aSMatthias Springer   LLVM_DEBUG({
6862fea658aSMatthias Springer     logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
6872fea658aSMatthias Springer                        << ")\n";
6882fea658aSMatthias Springer   });
689daf41890SIngo Müller   if (config.listener)
690daf41890SIngo Müller     config.listener->notifyOperationModified(op);
6912fea658aSMatthias Springer   addToWorklist(op);
6922fea658aSMatthias Springer }
6932fea658aSMatthias Springer 
694ddc98929Smlevesquedion void GreedyPatternRewriteDriver::addOperandsToWorklist(Operation *op) {
695ddc98929Smlevesquedion   for (Value operand : op->getOperands()) {
696ddc98929Smlevesquedion     // If this operand currently has at most 2 users, add its defining op to the
697ddc98929Smlevesquedion     // worklist. Indeed, after the op is deleted, then the operand will have at
698ddc98929Smlevesquedion     // most 1 user left. If it has 0 users left, it can be deleted too,
699ddc98929Smlevesquedion     // and if it has 1 user left, there may be further canonicalization
700ddc98929Smlevesquedion     // opportunities.
701ddc98929Smlevesquedion     if (!operand)
702b7144ab7SRiver Riddle       continue;
703ddc98929Smlevesquedion 
704ddc98929Smlevesquedion     auto *defOp = operand.getDefiningOp();
705ddc98929Smlevesquedion     if (!defOp)
706ddc98929Smlevesquedion       continue;
707ddc98929Smlevesquedion 
708ddc98929Smlevesquedion     Operation *otherUser = nullptr;
709ddc98929Smlevesquedion     bool hasMoreThanTwoUses = false;
710ddc98929Smlevesquedion     for (auto user : operand.getUsers()) {
711ddc98929Smlevesquedion       if (user == op || user == otherUser)
712ddc98929Smlevesquedion         continue;
713ddc98929Smlevesquedion       if (!otherUser) {
714ddc98929Smlevesquedion         otherUser = user;
715ddc98929Smlevesquedion         continue;
716ddc98929Smlevesquedion       }
717ddc98929Smlevesquedion       hasMoreThanTwoUses = true;
718ddc98929Smlevesquedion       break;
719ddc98929Smlevesquedion     }
720ddc98929Smlevesquedion     if (hasMoreThanTwoUses)
721ddc98929Smlevesquedion       continue;
722ddc98929Smlevesquedion 
723b7144ab7SRiver Riddle     addToWorklist(defOp);
724b7144ab7SRiver Riddle   }
725b7144ab7SRiver Riddle }
726b7144ab7SRiver Riddle 
727914e6074SMatthias Springer void GreedyPatternRewriteDriver::notifyOperationErased(Operation *op) {
7286e5021b8SMatthias Springer   LLVM_DEBUG({
7296e5021b8SMatthias Springer     logger.startLine() << "** Erase   : '" << op->getName() << "'(" << op
7306e5021b8SMatthias Springer                        << ")\n";
7316e5021b8SMatthias Springer   });
73273b86d1bSMatthias Springer 
73373b86d1bSMatthias Springer #ifndef NDEBUG
73473b86d1bSMatthias Springer   // Only ops that are within the configured scope are added to the worklist of
73573b86d1bSMatthias Springer   // the greedy pattern rewriter. Moreover, the parent op of the scope region is
73673b86d1bSMatthias Springer   // the part of the IR that is taken into account for the "expensive checks".
73773b86d1bSMatthias Springer   // A greedy pattern rewrite is not allowed to erase the parent op of the scope
73873b86d1bSMatthias Springer   // region, as that would break the worklist handling and the expensive checks.
73973b86d1bSMatthias Springer   if (config.scope && config.scope->getParentOp() == op)
74073b86d1bSMatthias Springer     llvm_unreachable(
74173b86d1bSMatthias Springer         "scope region must not be erased during greedy pattern rewrite");
74273b86d1bSMatthias Springer #endif // NDEBUG
74373b86d1bSMatthias Springer 
744279c1d2bSMatthias Springer   if (config.listener)
745914e6074SMatthias Springer     config.listener->notifyOperationErased(op);
7466e5021b8SMatthias Springer 
747ddc98929Smlevesquedion   addOperandsToWorklist(op);
748695a5a6aSMatthias Springer   worklist.remove(op);
7496bdecbcbSMatthias Springer 
7506bdecbcbSMatthias Springer   if (config.strictMode != GreedyRewriteStrictness::AnyOp)
7516bdecbcbSMatthias Springer     strictModeFilteredOps.erase(op);
752b7144ab7SRiver Riddle }
753b7144ab7SRiver Riddle 
754c6532830SMatthias Springer void GreedyPatternRewriteDriver::notifyOperationReplaced(
755c6532830SMatthias Springer     Operation *op, ValueRange replacement) {
7565652ecc3SRiver Riddle   LLVM_DEBUG({
7575652ecc3SRiver Riddle     logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
7585652ecc3SRiver Riddle                        << ")\n";
7595652ecc3SRiver Riddle   });
760279c1d2bSMatthias Springer   if (config.listener)
761279c1d2bSMatthias Springer     config.listener->notifyOperationReplaced(op, replacement);
762b7144ab7SRiver Riddle }
763b7144ab7SRiver Riddle 
7649a028afdSMatthias Springer void GreedyPatternRewriteDriver::notifyMatchFailure(
765ea64828aSRiver Riddle     Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
7665652ecc3SRiver Riddle   LLVM_DEBUG({
767ea64828aSRiver Riddle     Diagnostic diag(loc, DiagnosticSeverity::Remark);
7685652ecc3SRiver Riddle     reasonCallback(diag);
7699b6bd709SMatthias Springer     logger.startLine() << "** Match Failure : " << diag.str() << "\n";
7705652ecc3SRiver Riddle   });
771279c1d2bSMatthias Springer   if (config.listener)
7729a028afdSMatthias Springer     config.listener->notifyMatchFailure(loc, reasonCallback);
7735652ecc3SRiver Riddle }
7745652ecc3SRiver Riddle 
7759d5c63f6SMatthias Springer //===----------------------------------------------------------------------===//
7769d5c63f6SMatthias Springer // RegionPatternRewriteDriver
7779d5c63f6SMatthias Springer //===----------------------------------------------------------------------===//
7789d5c63f6SMatthias Springer 
7799d5c63f6SMatthias Springer namespace {
7809d5c63f6SMatthias Springer /// This driver simplfies all ops in a region.
7819d5c63f6SMatthias Springer class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver {
7829d5c63f6SMatthias Springer public:
7839d5c63f6SMatthias Springer   explicit RegionPatternRewriteDriver(MLIRContext *ctx,
7849d5c63f6SMatthias Springer                                       const FrozenRewritePatternSet &patterns,
7859d5c63f6SMatthias Springer                                       const GreedyRewriteConfig &config,
7869d5c63f6SMatthias Springer                                       Region &regions);
7879d5c63f6SMatthias Springer 
7889d5c63f6SMatthias Springer   /// Simplify ops inside `region` and simplify the region itself. Return
7899d5c63f6SMatthias Springer   /// success if the transformation converged.
7908498c9e9SJoel Wee   LogicalResult simplify(bool *changed) &&;
7919d5c63f6SMatthias Springer 
7929d5c63f6SMatthias Springer private:
7939d5c63f6SMatthias Springer   /// The region that is simplified.
7949d5c63f6SMatthias Springer   Region &region;
7959d5c63f6SMatthias Springer };
7969d5c63f6SMatthias Springer } // namespace
7979d5c63f6SMatthias Springer 
7989d5c63f6SMatthias Springer RegionPatternRewriteDriver::RegionPatternRewriteDriver(
7999d5c63f6SMatthias Springer     MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
8009d5c63f6SMatthias Springer     const GreedyRewriteConfig &config, Region &region)
8019d5c63f6SMatthias Springer     : GreedyPatternRewriteDriver(ctx, patterns, config), region(region) {
8029d5c63f6SMatthias Springer   // Populate strict mode ops.
8039d5c63f6SMatthias Springer   if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
8049d5c63f6SMatthias Springer     region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); });
8059d5c63f6SMatthias Springer   }
8069d5c63f6SMatthias Springer }
8079d5c63f6SMatthias Springer 
80887e6e490SMehdi Amini namespace {
80987e6e490SMehdi Amini class GreedyPatternRewriteIteration
81087e6e490SMehdi Amini     : public tracing::ActionImpl<GreedyPatternRewriteIteration> {
81187e6e490SMehdi Amini public:
81287e6e490SMehdi Amini   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GreedyPatternRewriteIteration)
81387e6e490SMehdi Amini   GreedyPatternRewriteIteration(ArrayRef<IRUnit> units, int64_t iteration)
81487e6e490SMehdi Amini       : tracing::ActionImpl<GreedyPatternRewriteIteration>(units),
81587e6e490SMehdi Amini         iteration(iteration) {}
81687e6e490SMehdi Amini   static constexpr StringLiteral tag = "GreedyPatternRewriteIteration";
81787e6e490SMehdi Amini   void print(raw_ostream &os) const override {
81887e6e490SMehdi Amini     os << "GreedyPatternRewriteIteration(" << iteration << ")";
81987e6e490SMehdi Amini   }
82087e6e490SMehdi Amini 
82187e6e490SMehdi Amini private:
82287e6e490SMehdi Amini   int64_t iteration = 0;
82387e6e490SMehdi Amini };
82487e6e490SMehdi Amini } // namespace
82587e6e490SMehdi Amini 
8268498c9e9SJoel Wee LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
8278498c9e9SJoel Wee   bool continueRewrites = false;
8289d5c63f6SMatthias Springer   int64_t iteration = 0;
8296b3e0002SMatthias Springer   MLIRContext *ctx = rewriter.getContext();
8309d5c63f6SMatthias Springer   do {
8319d5c63f6SMatthias Springer     // Check if the iteration limit was reached.
8328498c9e9SJoel Wee     if (++iteration > config.maxIterations &&
8339d5c63f6SMatthias Springer         config.maxIterations != GreedyRewriteConfig::kNoLimit)
8349d5c63f6SMatthias Springer       break;
8359d5c63f6SMatthias Springer 
836bb6d5c22SMatthias Springer     // New iteration: start with an empty worklist.
8379d5c63f6SMatthias Springer     worklist.clear();
8389d5c63f6SMatthias Springer 
839bb6d5c22SMatthias Springer     // `OperationFolder` CSE's constant ops (and may move them into parents
840bb6d5c22SMatthias Springer     // regions to enable more aggressive CSE'ing).
8416b3e0002SMatthias Springer     OperationFolder folder(ctx, this);
842bb6d5c22SMatthias Springer     auto insertKnownConstant = [&](Operation *op) {
843bb6d5c22SMatthias Springer       // Check for existing constants when populating the worklist. This avoids
844bb6d5c22SMatthias Springer       // accidentally reversing the constant order during processing.
845bb6d5c22SMatthias Springer       Attribute constValue;
846bb6d5c22SMatthias Springer       if (matchPattern(op, m_Constant(&constValue)))
847bb6d5c22SMatthias Springer         if (!folder.insertKnownConstant(op, constValue))
848bb6d5c22SMatthias Springer           return true;
849bb6d5c22SMatthias Springer       return false;
850bb6d5c22SMatthias Springer     };
851bb6d5c22SMatthias Springer 
8529d5c63f6SMatthias Springer     if (!config.useTopDownTraversal) {
8539d5c63f6SMatthias Springer       // Add operations to the worklist in postorder.
8549d5c63f6SMatthias Springer       region.walk([&](Operation *op) {
85509dfc571SJacques Pienaar         if (!config.cseConstants || !insertKnownConstant(op))
8569d5c63f6SMatthias Springer           addToWorklist(op);
8579d5c63f6SMatthias Springer       });
8589d5c63f6SMatthias Springer     } else {
8599d5c63f6SMatthias Springer       // Add all nested operations to the worklist in preorder.
8609d5c63f6SMatthias Springer       region.walk<WalkOrder::PreOrder>([&](Operation *op) {
86109dfc571SJacques Pienaar         if (!config.cseConstants || !insertKnownConstant(op)) {
8629bdfa8dfSMatthias Springer           addToWorklist(op);
8639d5c63f6SMatthias Springer           return WalkResult::advance();
8649d5c63f6SMatthias Springer         }
8659d5c63f6SMatthias Springer         return WalkResult::skip();
8669d5c63f6SMatthias Springer       });
8679d5c63f6SMatthias Springer 
8689d5c63f6SMatthias Springer       // Reverse the list so our pop-back loop processes them in-order.
869ca7167d5SMatthias Springer       worklist.reverse();
8709d5c63f6SMatthias Springer     }
8719d5c63f6SMatthias Springer 
87287e6e490SMehdi Amini     ctx->executeAction<GreedyPatternRewriteIteration>(
87387e6e490SMehdi Amini         [&] {
8748498c9e9SJoel Wee           continueRewrites = processWorklist();
8759d5c63f6SMatthias Springer 
87687e6e490SMehdi Amini           // After applying patterns, make sure that the CFG of each of the
87787e6e490SMehdi Amini           // regions is kept up to date.
878a506279eSMehdi Amini           if (config.enableRegionSimplification !=
879a506279eSMehdi Amini               GreedySimplifyRegionLevel::Disabled) {
880a506279eSMehdi Amini             continueRewrites |= succeeded(simplifyRegions(
881a506279eSMehdi Amini                 rewriter, region,
882a506279eSMehdi Amini                 /*mergeBlocks=*/config.enableRegionSimplification ==
883a506279eSMehdi Amini                     GreedySimplifyRegionLevel::Aggressive));
884a506279eSMehdi Amini           }
88587e6e490SMehdi Amini         },
88687e6e490SMehdi Amini         {&region}, iteration);
8878498c9e9SJoel Wee   } while (continueRewrites);
8888498c9e9SJoel Wee 
8898498c9e9SJoel Wee   if (changed)
8908498c9e9SJoel Wee     *changed = iteration > 1;
8919d5c63f6SMatthias Springer 
8929d5c63f6SMatthias Springer   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
8938498c9e9SJoel Wee   return success(!continueRewrites);
8949d5c63f6SMatthias Springer }
8959d5c63f6SMatthias Springer 
8963e98fbf4SRiver Riddle LogicalResult
89709dfc571SJacques Pienaar mlir::applyPatternsGreedily(Region &region,
8980b20413eSUday Bondhugula                             const FrozenRewritePatternSet &patterns,
8998498c9e9SJoel Wee                             GreedyRewriteConfig config, bool *changed) {
900e7a2ef21SRiver Riddle   // The top-level operation must be known to be isolated from above to
901e7a2ef21SRiver Riddle   // prevent performing canonicalizations on operations defined at or above
902e7a2ef21SRiver Riddle   // the region containing 'op'.
903a2b837abSMatthias Springer   assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
9046b1cc3c6SRiver Riddle          "patterns can only be applied to operations IsolatedFromAbove");
905e7a2ef21SRiver Riddle 
906977cddb9SMatthias Springer   // Set scope if not specified.
907977cddb9SMatthias Springer   if (!config.scope)
908977cddb9SMatthias Springer     config.scope = &region;
909977cddb9SMatthias Springer 
91073b86d1bSMatthias Springer #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
91173b86d1bSMatthias Springer   if (failed(verify(config.scope->getParentOp())))
91273b86d1bSMatthias Springer     llvm::report_fatal_error(
91373b86d1bSMatthias Springer         "greedy pattern rewriter input IR failed to verify");
91473b86d1bSMatthias Springer #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
91573b86d1bSMatthias Springer 
9166b1cc3c6SRiver Riddle   // Start the pattern driver.
9179d5c63f6SMatthias Springer   RegionPatternRewriteDriver driver(region.getContext(), patterns, config,
9189d5c63f6SMatthias Springer                                     region);
9198498c9e9SJoel Wee   LogicalResult converged = std::move(driver).simplify(changed);
9209d5c63f6SMatthias Springer   LLVM_DEBUG(if (failed(converged)) {
9210ff3cf0cSMatthias Springer     llvm::dbgs() << "The pattern rewrite did not converge after scanning "
92264716b2cSChris Lattner                  << config.maxIterations << " times\n";
9235c757087SFeng Liu   });
9249d5c63f6SMatthias Springer   return converged;
92564d52014SChris Lattner }
92604b5274eSUday Bondhugula 
92704b5274eSUday Bondhugula //===----------------------------------------------------------------------===//
9287932d21fSUday Bondhugula // MultiOpPatternRewriteDriver
9297932d21fSUday Bondhugula //===----------------------------------------------------------------------===//
9307932d21fSUday Bondhugula 
9317932d21fSUday Bondhugula namespace {
9329d5c63f6SMatthias Springer /// This driver simplfies a list of ops.
9337932d21fSUday Bondhugula class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
9347932d21fSUday Bondhugula public:
93567760d7eSMatthias Springer   explicit MultiOpPatternRewriteDriver(
93667760d7eSMatthias Springer       MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
9379d5c63f6SMatthias Springer       const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
9389d5c63f6SMatthias Springer       llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr);
9397932d21fSUday Bondhugula 
9409d5c63f6SMatthias Springer   /// Simplify `ops`. Return `success` if the transformation converged.
9419d5c63f6SMatthias Springer   LogicalResult simplify(ArrayRef<Operation *> ops, bool *changed = nullptr) &&;
9427932d21fSUday Bondhugula 
943ba3a9f51SChia-hung Duan private:
944914e6074SMatthias Springer   void notifyOperationErased(Operation *op) override {
945914e6074SMatthias Springer     GreedyPatternRewriteDriver::notifyOperationErased(op);
946774416bdSMatthias Springer     if (survivingOps)
947774416bdSMatthias Springer       survivingOps->erase(op);
9487932d21fSUday Bondhugula   }
9497932d21fSUday Bondhugula 
950774416bdSMatthias Springer   /// An optional set of ops that survived the rewrite. This set is populated
951774416bdSMatthias Springer   /// at the beginning of `simplifyLocally` with the inititally provided list
952774416bdSMatthias Springer   /// of ops.
95367760d7eSMatthias Springer   llvm::SmallDenseSet<Operation *, 4> *const survivingOps = nullptr;
9547932d21fSUday Bondhugula };
955be0a7e9fSMehdi Amini } // namespace
9567932d21fSUday Bondhugula 
9579d5c63f6SMatthias Springer MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
9589d5c63f6SMatthias Springer     MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
9599d5c63f6SMatthias Springer     const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
9609d5c63f6SMatthias Springer     llvm::SmallDenseSet<Operation *, 4> *survivingOps)
9619d5c63f6SMatthias Springer     : GreedyPatternRewriteDriver(ctx, patterns, config),
9629d5c63f6SMatthias Springer       survivingOps(survivingOps) {
9639d5c63f6SMatthias Springer   if (config.strictMode != GreedyRewriteStrictness::AnyOp)
9649d5c63f6SMatthias Springer     strictModeFilteredOps.insert(ops.begin(), ops.end());
9659d5c63f6SMatthias Springer 
96667760d7eSMatthias Springer   if (survivingOps) {
967774416bdSMatthias Springer     survivingOps->clear();
968774416bdSMatthias Springer     survivingOps->insert(ops.begin(), ops.end());
969774416bdSMatthias Springer   }
9707932d21fSUday Bondhugula }
9717932d21fSUday Bondhugula 
9729d5c63f6SMatthias Springer LogicalResult MultiOpPatternRewriteDriver::simplify(ArrayRef<Operation *> ops,
9739d5c63f6SMatthias Springer                                                     bool *changed) && {
9749d5c63f6SMatthias Springer   // Populate the initial worklist.
9757932d21fSUday Bondhugula   for (Operation *op : ops)
976e195e6baSMatthias Springer     addSingleOpToWorklist(op);
9777932d21fSUday Bondhugula 
9789d5c63f6SMatthias Springer   // Process ops on the worklist.
9799d5c63f6SMatthias Springer   bool result = processWorklist();
980fefe655bSMatthias Springer   if (changed)
9819d5c63f6SMatthias Springer     *changed = result;
9827932d21fSUday Bondhugula 
983fefe655bSMatthias Springer   return success(worklist.empty());
9847932d21fSUday Bondhugula }
9857932d21fSUday Bondhugula 
986e195e6baSMatthias Springer /// Find the region that is the closest common ancestor of all given ops.
987724a0e2cSMatthias Springer ///
988724a0e2cSMatthias Springer /// Note: This function returns `nullptr` if there is a top-level op among the
989724a0e2cSMatthias Springer /// given list of ops.
990e195e6baSMatthias Springer static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
991e195e6baSMatthias Springer   assert(!ops.empty() && "expected at least one op");
992e195e6baSMatthias Springer   // Fast path in case there is only one op.
993e195e6baSMatthias Springer   if (ops.size() == 1)
994e195e6baSMatthias Springer     return ops.front()->getParentRegion();
995e195e6baSMatthias Springer 
996e195e6baSMatthias Springer   Region *region = ops.front()->getParentRegion();
997e195e6baSMatthias Springer   ops = ops.drop_front();
998e195e6baSMatthias Springer   int sz = ops.size();
999e195e6baSMatthias Springer   llvm::BitVector remainingOps(sz, true);
1000724a0e2cSMatthias Springer   while (region) {
1001e195e6baSMatthias Springer     int pos = -1;
1002e195e6baSMatthias Springer     // Iterate over all remaining ops.
1003e195e6baSMatthias Springer     while ((pos = remainingOps.find_first_in(pos + 1, sz)) != -1) {
1004e195e6baSMatthias Springer       // Is this op contained in `region`?
1005e195e6baSMatthias Springer       if (region->findAncestorOpInRegion(*ops[pos]))
1006e195e6baSMatthias Springer         remainingOps.reset(pos);
1007e195e6baSMatthias Springer     }
1008e195e6baSMatthias Springer     if (remainingOps.none())
1009e195e6baSMatthias Springer       break;
1010724a0e2cSMatthias Springer     region = region->getParentRegion();
1011724a0e2cSMatthias Springer   }
1012e195e6baSMatthias Springer   return region;
1013e195e6baSMatthias Springer }
1014e195e6baSMatthias Springer 
101509dfc571SJacques Pienaar LogicalResult mlir::applyOpPatternsGreedily(
1016977cddb9SMatthias Springer     ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
10176bdecbcbSMatthias Springer     GreedyRewriteConfig config, bool *changed, bool *allErased) {
1018fefe655bSMatthias Springer   if (ops.empty()) {
1019fefe655bSMatthias Springer     if (changed)
1020fefe655bSMatthias Springer       *changed = false;
1021774416bdSMatthias Springer     if (allErased)
1022774416bdSMatthias Springer       *allErased = true;
1023fefe655bSMatthias Springer     return success();
1024fefe655bSMatthias Springer   }
10257932d21fSUday Bondhugula 
1026977cddb9SMatthias Springer   // Determine scope of rewrite.
1027977cddb9SMatthias Springer   if (!config.scope) {
1028724a0e2cSMatthias Springer     // Compute scope if none was provided. The scope will remain `nullptr` if
1029724a0e2cSMatthias Springer     // there is a top-level op among `ops`.
1030977cddb9SMatthias Springer     config.scope = findCommonAncestor(ops);
1031e195e6baSMatthias Springer   } else {
1032e195e6baSMatthias Springer     // If a scope was provided, make sure that all ops are in scope.
1033e195e6baSMatthias Springer #ifndef NDEBUG
1034e195e6baSMatthias Springer     bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) {
1035977cddb9SMatthias Springer       return static_cast<bool>(config.scope->findAncestorOpInRegion(*op));
1036e195e6baSMatthias Springer     });
1037e195e6baSMatthias Springer     assert(allOpsInScope && "ops must be within the specified scope");
1038e195e6baSMatthias Springer #endif // NDEBUG
1039e195e6baSMatthias Springer   }
1040e195e6baSMatthias Springer 
104173b86d1bSMatthias Springer #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
104273b86d1bSMatthias Springer   if (config.scope && failed(verify(config.scope->getParentOp())))
104373b86d1bSMatthias Springer     llvm::report_fatal_error(
104473b86d1bSMatthias Springer         "greedy pattern rewriter input IR failed to verify");
104573b86d1bSMatthias Springer #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
104673b86d1bSMatthias Springer 
10477932d21fSUday Bondhugula   // Start the pattern driver.
1048774416bdSMatthias Springer   llvm::SmallDenseSet<Operation *, 4> surviving;
104967760d7eSMatthias Springer   MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
10509d5c63f6SMatthias Springer                                      config, ops,
10519d5c63f6SMatthias Springer                                      allErased ? &surviving : nullptr);
10529d5c63f6SMatthias Springer   LogicalResult converged = std::move(driver).simplify(ops, changed);
1053774416bdSMatthias Springer   if (allErased)
1054774416bdSMatthias Springer     *allErased = surviving.empty();
1055cadd5666SMatthias Springer   LLVM_DEBUG(if (failed(converged)) {
1056cadd5666SMatthias Springer     llvm::dbgs() << "The pattern rewrite did not converge after "
1057977cddb9SMatthias Springer                  << config.maxNumRewrites << " rewrites";
1058cadd5666SMatthias Springer   });
1059774416bdSMatthias Springer   return converged;
10607932d21fSUday Bondhugula }
1061