xref: /llvm-project/mlir/include/mlir/IR/PatternMatch.h (revision 0f8a6b7d03550cb58cf49535af2de2230abfe997)
1 //===- PatternMatch.h - PatternMatcher classes -------==---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_PATTERNMATCH_H
10 #define MLIR_IR_PATTERNMATCH_H
11 
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "llvm/ADT/FunctionExtras.h"
15 #include "llvm/Support/TypeName.h"
16 #include <optional>
17 
18 using llvm::SmallPtrSetImpl;
19 namespace mlir {
20 
21 class PatternRewriter;
22 
23 //===----------------------------------------------------------------------===//
24 // PatternBenefit class
25 //===----------------------------------------------------------------------===//
26 
/// This class represents the benefit of a pattern match in a unitless scheme
/// that ranges from 0 (very little benefit) to 65K.  The most common unit to
/// use here is the "number of operations matched" by the pattern.
///
/// This also has a sentinel representation that can be used for patterns that
/// fail to match.
///
class PatternBenefit {
  /// Reserved representation meaning "this pattern can never match". It is
  /// also the value held by a default-constructed benefit.
  enum { ImpossibleToMatchSentinel = 65535 };

public:
  /// Construct the "impossible to match" benefit.
  PatternBenefit() = default;
  /// Construct a benefit with the given value. This is intentionally implicit
  /// so plain integers (e.g. a default argument `benefit = 1`) can be used
  /// wherever a PatternBenefit is expected.
  PatternBenefit(unsigned benefit);
  PatternBenefit(const PatternBenefit &) = default;
  PatternBenefit &operator=(const PatternBenefit &) = default;

  /// Return the sentinel benefit used by patterns that can never match.
  static PatternBenefit impossibleToMatch() { return PatternBenefit(); }
  /// Return true if this benefit is the "impossible to match" sentinel.
  bool isImpossibleToMatch() const { return *this == impossibleToMatch(); }

  /// If the corresponding pattern can match, return its benefit.  If the
  /// corresponding pattern isImpossibleToMatch() then this aborts.
  unsigned short getBenefit() const;

  /// Benefits are ordered by their raw representation. Because the
  /// impossible-to-match sentinel is the largest `unsigned short` value, it
  /// compares greater than or equal to every other benefit.
  bool operator==(const PatternBenefit &rhs) const {
    return representation == rhs.representation;
  }
  bool operator!=(const PatternBenefit &rhs) const {
    return representation != rhs.representation;
  }
  bool operator<(const PatternBenefit &rhs) const {
    return representation < rhs.representation;
  }
  bool operator>(const PatternBenefit &rhs) const {
    return representation > rhs.representation;
  }
  bool operator<=(const PatternBenefit &rhs) const {
    return representation <= rhs.representation;
  }
  bool operator>=(const PatternBenefit &rhs) const {
    return representation >= rhs.representation;
  }

private:
  /// The raw benefit value, or ImpossibleToMatchSentinel.
  unsigned short representation{ImpossibleToMatchSentinel};
};
64 
65 //===----------------------------------------------------------------------===//
66 // Pattern
67 //===----------------------------------------------------------------------===//
68 
69 /// This class contains all of the data related to a pattern, but does not
70 /// contain any methods or logic for the actual matching. This class is solely
71 /// used to interface with the metadata of a pattern, such as the benefit or
72 /// root operation.
73 class Pattern {
74   /// This enum represents the kind of value used to select the root operations
75   /// that match this pattern.
76   enum class RootKind {
77     /// The pattern root matches "any" operation.
78     Any,
79     /// The pattern root is matched using a concrete operation name.
80     OperationName,
81     /// The pattern root is matched using an interface ID.
82     InterfaceID,
83     /// The patter root is matched using a trait ID.
84     TraitID
85   };
86 
87 public:
88   /// Return a list of operations that may be generated when rewriting an
89   /// operation instance with this pattern.
90   ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }
91 
92   /// Return the root node that this pattern matches. Patterns that can match
93   /// multiple root types return std::nullopt.
94   std::optional<OperationName> getRootKind() const {
95     if (rootKind == RootKind::OperationName)
96       return OperationName::getFromOpaquePointer(rootValue);
97     return std::nullopt;
98   }
99 
100   /// Return the interface ID used to match the root operation of this pattern.
101   /// If the pattern does not use an interface ID for deciding the root match,
102   /// this returns std::nullopt.
103   std::optional<TypeID> getRootInterfaceID() const {
104     if (rootKind == RootKind::InterfaceID)
105       return TypeID::getFromOpaquePointer(rootValue);
106     return std::nullopt;
107   }
108 
109   /// Return the trait ID used to match the root operation of this pattern.
110   /// If the pattern does not use a trait ID for deciding the root match, this
111   /// returns std::nullopt.
112   std::optional<TypeID> getRootTraitID() const {
113     if (rootKind == RootKind::TraitID)
114       return TypeID::getFromOpaquePointer(rootValue);
115     return std::nullopt;
116   }
117 
118   /// Return the benefit (the inverse of "cost") of matching this pattern.  The
119   /// benefit of a Pattern is always static - rewrites that may have dynamic
120   /// benefit can be instantiated multiple times (different Pattern instances)
121   /// for each benefit that they may return, and be guarded by different match
122   /// condition predicates.
123   PatternBenefit getBenefit() const { return benefit; }
124 
125   /// Returns true if this pattern is known to result in recursive application,
126   /// i.e. this pattern may generate IR that also matches this pattern, but is
127   /// known to bound the recursion. This signals to a rewrite driver that it is
128   /// safe to apply this pattern recursively to generated IR.
129   bool hasBoundedRewriteRecursion() const {
130     return contextAndHasBoundedRecursion.getInt();
131   }
132 
133   /// Return the MLIRContext used to create this pattern.
134   MLIRContext *getContext() const {
135     return contextAndHasBoundedRecursion.getPointer();
136   }
137 
138   /// Return a readable name for this pattern. This name should only be used for
139   /// debugging purposes, and may be empty.
140   StringRef getDebugName() const { return debugName; }
141 
142   /// Set the human readable debug name used for this pattern. This name will
143   /// only be used for debugging purposes.
144   void setDebugName(StringRef name) { debugName = name; }
145 
146   /// Return the set of debug labels attached to this pattern.
147   ArrayRef<StringRef> getDebugLabels() const { return debugLabels; }
148 
149   /// Add the provided debug labels to this pattern.
150   void addDebugLabels(ArrayRef<StringRef> labels) {
151     debugLabels.append(labels.begin(), labels.end());
152   }
153   void addDebugLabels(StringRef label) { debugLabels.push_back(label); }
154 
155 protected:
156   /// This class acts as a special tag that makes the desire to match "any"
157   /// operation type explicit. This helps to avoid unnecessary usages of this
158   /// feature, and ensures that the user is making a conscious decision.
159   struct MatchAnyOpTypeTag {};
160   /// This class acts as a special tag that makes the desire to match any
161   /// operation that implements a given interface explicit. This helps to avoid
162   /// unnecessary usages of this feature, and ensures that the user is making a
163   /// conscious decision.
164   struct MatchInterfaceOpTypeTag {};
165   /// This class acts as a special tag that makes the desire to match any
166   /// operation that implements a given trait explicit. This helps to avoid
167   /// unnecessary usages of this feature, and ensures that the user is making a
168   /// conscious decision.
169   struct MatchTraitOpTypeTag {};
170 
171   /// Construct a pattern with a certain benefit that matches the operation
172   /// with the given root name.
173   Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context,
174           ArrayRef<StringRef> generatedNames = {});
175   /// Construct a pattern that may match any operation type. `generatedNames`
176   /// contains the names of operations that may be generated during a successful
177   /// rewrite. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
178   /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
179   /// always be supplied here.
180   Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit, MLIRContext *context,
181           ArrayRef<StringRef> generatedNames = {});
182   /// Construct a pattern that may match any operation that implements the
183   /// interface defined by the provided `interfaceID`. `generatedNames` contains
184   /// the names of operations that may be generated during a successful rewrite.
185   /// `MatchInterfaceOpTypeTag` is just a tag to ensure that the "match
186   /// interface" behavior is what the user actually desired,
187   /// `MatchInterfaceOpTypeTag()` should always be supplied here.
188   Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID,
189           PatternBenefit benefit, MLIRContext *context,
190           ArrayRef<StringRef> generatedNames = {});
191   /// Construct a pattern that may match any operation that implements the
192   /// trait defined by the provided `traitID`. `generatedNames` contains the
193   /// names of operations that may be generated during a successful rewrite.
194   /// `MatchTraitOpTypeTag` is just a tag to ensure that the "match trait"
195   /// behavior is what the user actually desired, `MatchTraitOpTypeTag()` should
196   /// always be supplied here.
197   Pattern(MatchTraitOpTypeTag tag, TypeID traitID, PatternBenefit benefit,
198           MLIRContext *context, ArrayRef<StringRef> generatedNames = {});
199 
200   /// Set the flag detailing if this pattern has bounded rewrite recursion or
201   /// not.
202   void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg = true) {
203     contextAndHasBoundedRecursion.setInt(hasBoundedRecursionArg);
204   }
205 
206 private:
207   Pattern(const void *rootValue, RootKind rootKind,
208           ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
209           MLIRContext *context);
210 
211   /// The value used to match the root operation of the pattern.
212   const void *rootValue;
213   RootKind rootKind;
214 
215   /// The expected benefit of matching this pattern.
216   const PatternBenefit benefit;
217 
218   /// The context this pattern was created from, and a boolean flag indicating
219   /// whether this pattern has bounded recursion or not.
220   llvm::PointerIntPair<MLIRContext *, 1, bool> contextAndHasBoundedRecursion;
221 
222   /// A list of the potential operations that may be generated when rewriting
223   /// an op with this pattern.
224   SmallVector<OperationName, 2> generatedOps;
225 
226   /// A readable name for this pattern. May be empty.
227   StringRef debugName;
228 
229   /// The set of debug labels attached to this pattern.
230   SmallVector<StringRef, 0> debugLabels;
231 };
232 
233 //===----------------------------------------------------------------------===//
234 // RewritePattern
235 //===----------------------------------------------------------------------===//
236 
237 /// RewritePattern is the common base class for all DAG to DAG replacements.
238 /// There are two possible usages of this class:
239 ///   * Multi-step RewritePattern with "match" and "rewrite"
240 ///     - By overloading the "match" and "rewrite" functions, the user can
241 ///       separate the concerns of matching and rewriting.
242 ///   * Single-step RewritePattern with "matchAndRewrite"
243 ///     - By overloading the "matchAndRewrite" function, the user can perform
244 ///       the rewrite in the same call as the match.
245 ///
246 class RewritePattern : public Pattern {
247 public:
248   virtual ~RewritePattern() = default;
249 
250   /// Rewrite the IR rooted at the specified operation with the result of
251   /// this pattern, generating any new operations with the specified
252   /// builder.  If an unexpected error is encountered (an internal
253   /// compiler error), it is emitted through the normal MLIR diagnostic
254   /// hooks and the IR is left in a valid state.
255   virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;
256 
257   /// Attempt to match against code rooted at the specified operation,
258   /// which is the same operation code as getRootKind().
259   virtual LogicalResult match(Operation *op) const;
260 
261   /// Attempt to match against code rooted at the specified operation,
262   /// which is the same operation code as getRootKind(). If successful, this
263   /// function will automatically perform the rewrite.
264   virtual LogicalResult matchAndRewrite(Operation *op,
265                                         PatternRewriter &rewriter) const {
266     if (succeeded(match(op))) {
267       rewrite(op, rewriter);
268       return success();
269     }
270     return failure();
271   }
272 
273   /// This method provides a convenient interface for creating and initializing
274   /// derived rewrite patterns of the given type `T`.
275   template <typename T, typename... Args>
276   static std::unique_ptr<T> create(Args &&...args) {
277     std::unique_ptr<T> pattern =
278         std::make_unique<T>(std::forward<Args>(args)...);
279     initializePattern<T>(*pattern);
280 
281     // Set a default debug name if one wasn't provided.
282     if (pattern->getDebugName().empty())
283       pattern->setDebugName(llvm::getTypeName<T>());
284     return pattern;
285   }
286 
287 protected:
288   /// Inherit the base constructors from `Pattern`.
289   using Pattern::Pattern;
290 
291 private:
292   /// Trait to check if T provides a `initialize` method.
293   template <typename T, typename... Args>
294   using has_initialize = decltype(std::declval<T>().initialize());
295   template <typename T>
296   using detect_has_initialize = llvm::is_detected<has_initialize, T>;
297 
298   /// Initialize the derived pattern by calling its `initialize` method.
299   template <typename T>
300   static std::enable_if_t<detect_has_initialize<T>::value>
301   initializePattern(T &pattern) {
302     pattern.initialize();
303   }
304   /// Empty derived pattern initializer for patterns that do not have an
305   /// initialize method.
306   template <typename T>
307   static std::enable_if_t<!detect_has_initialize<T>::value>
308   initializePattern(T &) {}
309 
310   /// An anchor for the virtual table.
311   virtual void anchor();
312 };
313 
314 namespace detail {
315 /// OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that
316 /// allows for matching and rewriting against an instance of a derived operation
317 /// class or Interface.
318 template <typename SourceOp>
319 struct OpOrInterfaceRewritePatternBase : public RewritePattern {
320   using RewritePattern::RewritePattern;
321 
322   /// Wrappers around the RewritePattern methods that pass the derived op type.
323   void rewrite(Operation *op, PatternRewriter &rewriter) const final {
324     rewrite(cast<SourceOp>(op), rewriter);
325   }
326   LogicalResult match(Operation *op) const final {
327     return match(cast<SourceOp>(op));
328   }
329   LogicalResult matchAndRewrite(Operation *op,
330                                 PatternRewriter &rewriter) const final {
331     return matchAndRewrite(cast<SourceOp>(op), rewriter);
332   }
333 
334   /// Rewrite and Match methods that operate on the SourceOp type. These must be
335   /// overridden by the derived pattern class.
336   virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
337     llvm_unreachable("must override rewrite or matchAndRewrite");
338   }
339   virtual LogicalResult match(SourceOp op) const {
340     llvm_unreachable("must override match or matchAndRewrite");
341   }
342   virtual LogicalResult matchAndRewrite(SourceOp op,
343                                         PatternRewriter &rewriter) const {
344     if (succeeded(match(op))) {
345       rewrite(op, rewriter);
346       return success();
347     }
348     return failure();
349   }
350 };
351 } // namespace detail
352 
353 /// OpRewritePattern is a wrapper around RewritePattern that allows for
354 /// matching and rewriting against an instance of a derived operation class as
355 /// opposed to a raw Operation.
356 template <typename SourceOp>
357 struct OpRewritePattern
358     : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
359   /// Patterns must specify the root operation name they match against, and can
360   /// also specify the benefit of the pattern matching and a list of generated
361   /// ops.
362   OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1,
363                    ArrayRef<StringRef> generatedNames = {})
364       : detail::OpOrInterfaceRewritePatternBase<SourceOp>(
365             SourceOp::getOperationName(), benefit, context, generatedNames) {}
366 };
367 
368 /// OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for
369 /// matching and rewriting against an instance of an operation interface instead
370 /// of a raw Operation.
371 template <typename SourceOp>
372 struct OpInterfaceRewritePattern
373     : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
374   OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
375       : detail::OpOrInterfaceRewritePatternBase<SourceOp>(
376             Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),
377             benefit, context) {}
378 };
379 
380 /// OpTraitRewritePattern is a wrapper around RewritePattern that allows for
381 /// matching and rewriting against instances of an operation that possess a
382 /// given trait.
383 template <template <typename> class TraitType>
384 class OpTraitRewritePattern : public RewritePattern {
385 public:
386   OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
387       : RewritePattern(Pattern::MatchTraitOpTypeTag(), TypeID::get<TraitType>(),
388                        benefit, context) {}
389 };
390 
391 //===----------------------------------------------------------------------===//
392 // RewriterBase
393 //===----------------------------------------------------------------------===//
394 
395 /// This class coordinates the application of a rewrite on a set of IR,
396 /// providing a way for clients to track mutations and create new operations.
397 /// This class serves as a common API for IR mutation between pattern rewrites
398 /// and non-pattern rewrites, and facilitates the development of shared
399 /// IR transformation utilities.
400 class RewriterBase : public OpBuilder {
401 public:
402   struct Listener : public OpBuilder::Listener {
403     Listener()
404         : OpBuilder::Listener(ListenerBase::Kind::RewriterBaseListener) {}
405 
406     /// Notify the listener that the specified block is about to be erased.
407     /// At this point, the block has zero uses.
408     virtual void notifyBlockErased(Block *block) {}
409 
410     /// Notify the listener that the specified operation was modified in-place.
411     virtual void notifyOperationModified(Operation *op) {}
412 
413     /// Notify the listener that all uses of the specified operation's results
414     /// are about to be replaced with the results of another operation. This is
415     /// called before the uses of the old operation have been changed.
416     ///
417     /// By default, this function calls the "operation replaced with values"
418     /// notification.
419     virtual void notifyOperationReplaced(Operation *op,
420                                          Operation *replacement) {
421       notifyOperationReplaced(op, replacement->getResults());
422     }
423 
424     /// Notify the listener that all uses of the specified operation's results
425     /// are about to be replaced with the a range of values, potentially
426     /// produced by other operations. This is called before the uses of the
427     /// operation have been changed.
428     virtual void notifyOperationReplaced(Operation *op,
429                                          ValueRange replacement) {}
430 
431     /// Notify the listener that the specified operation is about to be erased.
432     /// At this point, the operation has zero uses.
433     ///
434     /// Note: This notification is not triggered when unlinking an operation.
435     virtual void notifyOperationErased(Operation *op) {}
436 
437     /// Notify the listener that the specified pattern is about to be applied
438     /// at the specified root operation.
439     virtual void notifyPatternBegin(const Pattern &pattern, Operation *op) {}
440 
441     /// Notify the listener that a pattern application finished with the
442     /// specified status. "success" indicates that the pattern was applied
443     /// successfully. "failure" indicates that the pattern could not be
444     /// applied. The pattern may have communicated the reason for the failure
445     /// with `notifyMatchFailure`.
446     virtual void notifyPatternEnd(const Pattern &pattern,
447                                   LogicalResult status) {}
448 
449     /// Notify the listener that the pattern failed to match, and provide a
450     /// callback to populate a diagnostic with the reason why the failure
451     /// occurred. This method allows for derived listeners to optionally hook
452     /// into the reason why a rewrite failed, and display it to users.
453     virtual void
454     notifyMatchFailure(Location loc,
455                        function_ref<void(Diagnostic &)> reasonCallback) {}
456 
457     static bool classof(const OpBuilder::Listener *base);
458   };
459 
460   /// A listener that forwards all notifications to another listener. This
461   /// struct can be used as a base to create listener chains, so that multiple
462   /// listeners can be notified of IR changes.
463   struct ForwardingListener : public RewriterBase::Listener {
464     ForwardingListener(OpBuilder::Listener *listener)
465         : listener(listener),
466           rewriteListener(
467               dyn_cast_if_present<RewriterBase::Listener>(listener)) {}
468 
469     void notifyOperationInserted(Operation *op, InsertPoint previous) override {
470       if (listener)
471         listener->notifyOperationInserted(op, previous);
472     }
473     void notifyBlockInserted(Block *block, Region *previous,
474                              Region::iterator previousIt) override {
475       if (listener)
476         listener->notifyBlockInserted(block, previous, previousIt);
477     }
478     void notifyBlockErased(Block *block) override {
479       if (rewriteListener)
480         rewriteListener->notifyBlockErased(block);
481     }
482     void notifyOperationModified(Operation *op) override {
483       if (rewriteListener)
484         rewriteListener->notifyOperationModified(op);
485     }
486     void notifyOperationReplaced(Operation *op, Operation *newOp) override {
487       if (rewriteListener)
488         rewriteListener->notifyOperationReplaced(op, newOp);
489     }
490     void notifyOperationReplaced(Operation *op,
491                                  ValueRange replacement) override {
492       if (rewriteListener)
493         rewriteListener->notifyOperationReplaced(op, replacement);
494     }
495     void notifyOperationErased(Operation *op) override {
496       if (rewriteListener)
497         rewriteListener->notifyOperationErased(op);
498     }
499     void notifyPatternBegin(const Pattern &pattern, Operation *op) override {
500       if (rewriteListener)
501         rewriteListener->notifyPatternBegin(pattern, op);
502     }
503     void notifyPatternEnd(const Pattern &pattern,
504                           LogicalResult status) override {
505       if (rewriteListener)
506         rewriteListener->notifyPatternEnd(pattern, status);
507     }
508     void notifyMatchFailure(
509         Location loc,
510         function_ref<void(Diagnostic &)> reasonCallback) override {
511       if (rewriteListener)
512         rewriteListener->notifyMatchFailure(loc, reasonCallback);
513     }
514 
515   private:
516     OpBuilder::Listener *listener;
517     RewriterBase::Listener *rewriteListener;
518   };
519 
520   /// Move the blocks that belong to "region" before the given position in
521   /// another region "parent". The two regions must be different. The caller
522   /// is responsible for creating or updating the operation transferring flow
523   /// of control to the region and passing it the correct block arguments.
524   void inlineRegionBefore(Region &region, Region &parent,
525                           Region::iterator before);
526   void inlineRegionBefore(Region &region, Block *before);
527 
528   /// Replace the results of the given (original) operation with the specified
529   /// list of values (replacements). The result types of the given op and the
530   /// replacements must match. The original op is erased.
531   virtual void replaceOp(Operation *op, ValueRange newValues);
532 
533   /// Replace the results of the given (original) operation with the specified
534   /// new op (replacement). The result types of the two ops must match. The
535   /// original op is erased.
536   virtual void replaceOp(Operation *op, Operation *newOp);
537 
538   /// Replace the results of the given (original) op with a new op that is
539   /// created without verification (replacement). The result values of the two
540   /// ops must match. The original op is erased.
541   template <typename OpTy, typename... Args>
542   OpTy replaceOpWithNewOp(Operation *op, Args &&...args) {
543     auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
544     replaceOp(op, newOp.getOperation());
545     return newOp;
546   }
547 
548   /// This method erases an operation that is known to have no uses.
549   virtual void eraseOp(Operation *op);
550 
551   /// This method erases all operations in a block.
552   virtual void eraseBlock(Block *block);
553 
554   /// Inline the operations of block 'source' into block 'dest' before the given
555   /// position. The source block will be deleted and must have no uses.
556   /// 'argValues' is used to replace the block arguments of 'source'.
557   ///
558   /// If the source block is inserted at the end of the dest block, the dest
559   /// block must have no successors. Similarly, if the source block is inserted
560   /// somewhere in the middle (or beginning) of the dest block, the source block
561   /// must have no successors. Otherwise, the resulting IR would have
562   /// unreachable operations.
563   virtual void inlineBlockBefore(Block *source, Block *dest,
564                                  Block::iterator before,
565                                  ValueRange argValues = std::nullopt);
566 
567   /// Inline the operations of block 'source' before the operation 'op'. The
568   /// source block will be deleted and must have no uses. 'argValues' is used to
569   /// replace the block arguments of 'source'
570   ///
571   /// The source block must have no successors. Otherwise, the resulting IR
572   /// would have unreachable operations.
573   void inlineBlockBefore(Block *source, Operation *op,
574                          ValueRange argValues = std::nullopt);
575 
576   /// Inline the operations of block 'source' into the end of block 'dest'. The
577   /// source block will be deleted and must have no uses. 'argValues' is used to
578   /// replace the block arguments of 'source'
579   ///
580   /// The dest block must have no successors. Otherwise, the resulting IR would
581   /// have unreachable operation.
582   void mergeBlocks(Block *source, Block *dest,
583                    ValueRange argValues = std::nullopt);
584 
585   /// Split the operations starting at "before" (inclusive) out of the given
586   /// block into a new block, and return it.
587   Block *splitBlock(Block *block, Block::iterator before);
588 
589   /// Unlink this operation from its current block and insert it right before
590   /// `existingOp` which may be in the same or another block in the same
591   /// function.
592   void moveOpBefore(Operation *op, Operation *existingOp);
593 
594   /// Unlink this operation from its current block and insert it right before
595   /// `iterator` in the specified block.
596   void moveOpBefore(Operation *op, Block *block, Block::iterator iterator);
597 
598   /// Unlink this operation from its current block and insert it right after
599   /// `existingOp` which may be in the same or another block in the same
600   /// function.
601   void moveOpAfter(Operation *op, Operation *existingOp);
602 
603   /// Unlink this operation from its current block and insert it right after
604   /// `iterator` in the specified block.
605   void moveOpAfter(Operation *op, Block *block, Block::iterator iterator);
606 
607   /// Unlink this block and insert it right before `existingBlock`.
608   void moveBlockBefore(Block *block, Block *anotherBlock);
609 
610   /// Unlink this block and insert it right before the location that the given
611   /// iterator points to in the given region.
612   void moveBlockBefore(Block *block, Region *region, Region::iterator iterator);
613 
614   /// This method is used to notify the rewriter that an in-place operation
615   /// modification is about to happen. A call to this function *must* be
616   /// followed by a call to either `finalizeOpModification` or
617   /// `cancelOpModification`. This is a minor efficiency win (it avoids creating
618   /// a new operation and removing the old one) but also often allows simpler
619   /// code in the client.
620   virtual void startOpModification(Operation *op) {}
621 
622   /// This method is used to signal the end of an in-place modification of the
623   /// given operation. This can only be called on operations that were provided
624   /// to a call to `startOpModification`.
625   virtual void finalizeOpModification(Operation *op);
626 
627   /// This method cancels a pending in-place modification. This can only be
628   /// called on operations that were provided to a call to
629   /// `startOpModification`.
630   virtual void cancelOpModification(Operation *op) {}
631 
632   /// This method is a utility wrapper around an in-place modification of an
633   /// operation. It wraps calls to `startOpModification` and
634   /// `finalizeOpModification` around the given callable.
635   template <typename CallableT>
636   void modifyOpInPlace(Operation *root, CallableT &&callable) {
637     startOpModification(root);
638     callable();
639     finalizeOpModification(root);
640   }
641 
642   /// Find uses of `from` and replace them with `to`. Also notify the listener
643   /// about every in-place op modification (for every use that was replaced).
644   void replaceAllUsesWith(Value from, Value to) {
645     for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
646       Operation *op = operand.getOwner();
647       modifyOpInPlace(op, [&]() { operand.set(to); });
648     }
649   }
650   void replaceAllUsesWith(Block *from, Block *to) {
651     for (BlockOperand &operand : llvm::make_early_inc_range(from->getUses())) {
652       Operation *op = operand.getOwner();
653       modifyOpInPlace(op, [&]() { operand.set(to); });
654     }
655   }
656   void replaceAllUsesWith(ValueRange from, ValueRange to) {
657     assert(from.size() == to.size() && "incorrect number of replacements");
658     for (auto it : llvm::zip(from, to))
659       replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
660   }
661 
  /// Find uses of `from` (i.e. of its results) and replace them with `to`.
  /// Also notify the listener about every in-place op modification (for every
  /// use that was replaced) and that the `from` operation is about to be
  /// replaced.
  ///
  /// The `Operation *` overload takes the replacement values from the results
  /// of `to`. NOTE(review): presumably `to` must supply exactly one value per
  /// result of `from`; the out-of-line definitions are not visible here.
  ///
  /// Note: This function cannot be called `replaceAllUsesWith` because the
  /// overload resolution, when called with an op that can be implicitly
  /// converted to a Value, would be ambiguous.
  void replaceAllOpUsesWith(Operation *from, ValueRange to);
  void replaceAllOpUsesWith(Operation *from, Operation *to);
671 
  /// Find uses of `from` and replace them with `to` if the `functor` returns
  /// true. `functor` receives each candidate use as an `OpOperand`. Also notify
  /// the listener about every in-place op modification (for every use that
  /// was replaced). The optional `allUsesReplaced` flag is set to "true" if
  /// all uses were replaced (presumably "false" otherwise -- the definitions
  /// are out of line).
  ///
  /// The `ValueRange` overload pairs `from[i]` with `to[i]`; NOTE(review):
  /// presumably both ranges must have the same length.
  void replaceUsesWithIf(Value from, Value to,
                         function_ref<bool(OpOperand &)> functor,
                         bool *allUsesReplaced = nullptr);
  void replaceUsesWithIf(ValueRange from, ValueRange to,
                         function_ref<bool(OpOperand &)> functor,
                         bool *allUsesReplaced = nullptr);
682   // Note: This function cannot be called `replaceOpUsesWithIf` because the
683   // overload resolution, when called with an op that can be implicitly
684   // converted to a Value, would be ambiguous.
685   void replaceOpUsesWithIf(Operation *from, ValueRange to,
686                            function_ref<bool(OpOperand &)> functor,
687                            bool *allUsesReplaced = nullptr) {
688     replaceUsesWithIf(from->getResults(), to, functor, allUsesReplaced);
689   }
690 
691   /// Find uses of `from` within `block` and replace them with `to`. Also notify
692   /// the listener about every in-place op modification (for every use that was
693   /// replaced). The optional `allUsesReplaced` flag is set to "true" if all
694   /// uses were replaced.
695   void replaceOpUsesWithinBlock(Operation *op, ValueRange newValues,
696                                 Block *block, bool *allUsesReplaced = nullptr) {
697     replaceOpUsesWithIf(
698         op, newValues,
699         [block](OpOperand &use) {
700           return block->getParentOp()->isProperAncestor(use.getOwner());
701         },
702         allUsesReplaced);
703   }
704 
705   /// Find uses of `from` and replace them with `to` except if the user is
706   /// `exceptedUser`. Also notify the listener about every in-place op
707   /// modification (for every use that was replaced).
708   void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser) {
709     return replaceUsesWithIf(from, to, [&](OpOperand &use) {
710       Operation *user = use.getOwner();
711       return user != exceptedUser;
712     });
713   }
714   void replaceAllUsesExcept(Value from, Value to,
715                             const SmallPtrSetImpl<Operation *> &preservedUsers);
716 
717   /// Used to notify the listener that the IR failed to be rewritten because of
718   /// a match failure, and provide a callback to populate a diagnostic with the
719   /// reason why the failure occurred. This method allows for derived rewriters
720   /// to optionally hook into the reason why a rewrite failed, and display it to
721   /// users.
722   template <typename CallbackT>
723   std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
724   notifyMatchFailure(Location loc, CallbackT &&reasonCallback) {
725     if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
726       rewriteListener->notifyMatchFailure(
727           loc, function_ref<void(Diagnostic &)>(reasonCallback));
728     return failure();
729   }
730   template <typename CallbackT>
731   std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
732   notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) {
733     if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
734       rewriteListener->notifyMatchFailure(
735           op->getLoc(), function_ref<void(Diagnostic &)>(reasonCallback));
736     return failure();
737   }
738   template <typename ArgT>
739   LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg) {
740     return notifyMatchFailure(std::forward<ArgT>(arg),
741                               [&](Diagnostic &diag) { diag << msg; });
742   }
743   template <typename ArgT>
744   LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg) {
745     return notifyMatchFailure(std::forward<ArgT>(arg), Twine(msg));
746   }
747 
748 protected:
749   /// Initialize the builder.
750   explicit RewriterBase(MLIRContext *ctx,
751                         OpBuilder::Listener *listener = nullptr)
752       : OpBuilder(ctx, listener) {}
753   explicit RewriterBase(const OpBuilder &otherBuilder)
754       : OpBuilder(otherBuilder) {}
755   explicit RewriterBase(Operation *op, OpBuilder::Listener *listener = nullptr)
756       : OpBuilder(op, listener) {}
757   virtual ~RewriterBase();
758 
759 private:
760   void operator=(const RewriterBase &) = delete;
761   RewriterBase(const RewriterBase &) = delete;
762 };
763 
764 //===----------------------------------------------------------------------===//
765 // IRRewriter
766 //===----------------------------------------------------------------------===//
767 
768 /// This class coordinates rewriting a piece of IR outside of a pattern rewrite,
769 /// providing a way to keep track of the mutations made to the IR. This class
770 /// should only be used in situations where another `RewriterBase` instance,
771 /// such as a `PatternRewriter`, is not available.
772 class IRRewriter : public RewriterBase {
773 public:
774   explicit IRRewriter(MLIRContext *ctx, OpBuilder::Listener *listener = nullptr)
775       : RewriterBase(ctx, listener) {}
776   explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {}
777   explicit IRRewriter(Operation *op, OpBuilder::Listener *listener = nullptr)
778       : RewriterBase(op, listener) {}
779 };
780 
781 //===----------------------------------------------------------------------===//
782 // PatternRewriter
783 //===----------------------------------------------------------------------===//
784 
785 /// A special type of `RewriterBase` that coordinates the application of a
786 /// rewrite pattern on the current IR being matched, providing a way to keep
787 /// track of any mutations made. This class should be used to perform all
788 /// necessary IR mutations within a rewrite pattern, as the pattern driver may
789 /// be tracking various state that would be invalidated when a mutation takes
790 /// place.
791 class PatternRewriter : public RewriterBase {
792 public:
793   explicit PatternRewriter(MLIRContext *ctx) : RewriterBase(ctx) {}
794   using RewriterBase::RewriterBase;
795 
796   /// A hook used to indicate if the pattern rewriter can recover from failure
797   /// during the rewrite stage of a pattern. For example, if the pattern
798   /// rewriter supports rollback, it may progress smoothly even if IR was
799   /// changed during the rewrite.
800   virtual bool canRecoverFromRewriteFailure() const { return false; }
801 };
802 
803 } // namespace mlir
804 
805 // Optionally expose PDL pattern matching methods.
806 #include "PDLPatternMatch.h.inc"
807 
808 namespace mlir {
809 
810 //===----------------------------------------------------------------------===//
811 // RewritePatternSet
812 //===----------------------------------------------------------------------===//
813 
814 class RewritePatternSet {
815   using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
816 
817 public:
818   RewritePatternSet(MLIRContext *context) : context(context) {}
819 
820   /// Construct a RewritePatternSet populated with the given pattern.
821   RewritePatternSet(MLIRContext *context,
822                     std::unique_ptr<RewritePattern> pattern)
823       : context(context) {
824     nativePatterns.emplace_back(std::move(pattern));
825   }
826   RewritePatternSet(PDLPatternModule &&pattern)
827       : context(pattern.getContext()), pdlPatterns(std::move(pattern)) {}
828 
829   MLIRContext *getContext() const { return context; }
830 
831   /// Return the native patterns held in this list.
832   NativePatternListT &getNativePatterns() { return nativePatterns; }
833 
834   /// Return the PDL patterns held in this list.
835   PDLPatternModule &getPDLPatterns() { return pdlPatterns; }
836 
837   /// Clear out all of the held patterns in this list.
838   void clear() {
839     nativePatterns.clear();
840     pdlPatterns.clear();
841   }
842 
843   //===--------------------------------------------------------------------===//
844   // 'add' methods for adding patterns to the set.
845   //===--------------------------------------------------------------------===//
846 
847   /// Add an instance of each of the pattern types 'Ts' to the pattern list with
848   /// the given arguments. Return a reference to `this` for chaining insertions.
849   /// Note: ConstructorArg is necessary here to separate the two variadic lists.
850   template <typename... Ts, typename ConstructorArg,
851             typename... ConstructorArgs,
852             typename = std::enable_if_t<sizeof...(Ts) != 0>>
853   RewritePatternSet &add(ConstructorArg &&arg, ConstructorArgs &&...args) {
854     // The following expands a call to emplace_back for each of the pattern
855     // types 'Ts'.
856     (addImpl<Ts>(/*debugLabels=*/std::nullopt,
857                  std::forward<ConstructorArg>(arg),
858                  std::forward<ConstructorArgs>(args)...),
859      ...);
860     return *this;
861   }
862   /// An overload of the above `add` method that allows for attaching a set
863   /// of debug labels to the attached patterns. This is useful for labeling
864   /// groups of patterns that may be shared between multiple different
865   /// passes/users.
866   template <typename... Ts, typename ConstructorArg,
867             typename... ConstructorArgs,
868             typename = std::enable_if_t<sizeof...(Ts) != 0>>
869   RewritePatternSet &addWithLabel(ArrayRef<StringRef> debugLabels,
870                                   ConstructorArg &&arg,
871                                   ConstructorArgs &&...args) {
872     // The following expands a call to emplace_back for each of the pattern
873     // types 'Ts'.
874     (addImpl<Ts>(debugLabels, arg, args...), ...);
875     return *this;
876   }
877 
878   /// Add an instance of each of the pattern types 'Ts'. Return a reference to
879   /// `this` for chaining insertions.
880   template <typename... Ts>
881   RewritePatternSet &add() {
882     (addImpl<Ts>(), ...);
883     return *this;
884   }
885 
886   /// Add the given native pattern to the pattern list. Return a reference to
887   /// `this` for chaining insertions.
888   RewritePatternSet &add(std::unique_ptr<RewritePattern> pattern) {
889     nativePatterns.emplace_back(std::move(pattern));
890     return *this;
891   }
892 
893   /// Add the given PDL pattern to the pattern list. Return a reference to
894   /// `this` for chaining insertions.
895   RewritePatternSet &add(PDLPatternModule &&pattern) {
896     pdlPatterns.mergeIn(std::move(pattern));
897     return *this;
898   }
899 
900   // Add a matchAndRewrite style pattern represented as a C function pointer.
901   template <typename OpType>
902   RewritePatternSet &
903   add(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
904       PatternBenefit benefit = 1, ArrayRef<StringRef> generatedNames = {}) {
905     struct FnPattern final : public OpRewritePattern<OpType> {
906       FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
907                 MLIRContext *context, PatternBenefit benefit,
908                 ArrayRef<StringRef> generatedNames)
909           : OpRewritePattern<OpType>(context, benefit, generatedNames),
910             implFn(implFn) {}
911 
912       LogicalResult matchAndRewrite(OpType op,
913                                     PatternRewriter &rewriter) const override {
914         return implFn(op, rewriter);
915       }
916 
917     private:
918       LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
919     };
920     add(std::make_unique<FnPattern>(std::move(implFn), getContext(), benefit,
921                                     generatedNames));
922     return *this;
923   }
924 
925   //===--------------------------------------------------------------------===//
926   // Pattern Insertion
927   //===--------------------------------------------------------------------===//
928 
929   // TODO: These are soft deprecated in favor of the 'add' methods above.
930 
931   /// Add an instance of each of the pattern types 'Ts' to the pattern list with
932   /// the given arguments. Return a reference to `this` for chaining insertions.
933   /// Note: ConstructorArg is necessary here to separate the two variadic lists.
934   template <typename... Ts, typename ConstructorArg,
935             typename... ConstructorArgs,
936             typename = std::enable_if_t<sizeof...(Ts) != 0>>
937   RewritePatternSet &insert(ConstructorArg &&arg, ConstructorArgs &&...args) {
938     // The following expands a call to emplace_back for each of the pattern
939     // types 'Ts'.
940     (addImpl<Ts>(/*debugLabels=*/std::nullopt, arg, args...), ...);
941     return *this;
942   }
943 
944   /// Add an instance of each of the pattern types 'Ts'. Return a reference to
945   /// `this` for chaining insertions.
946   template <typename... Ts>
947   RewritePatternSet &insert() {
948     (addImpl<Ts>(), ...);
949     return *this;
950   }
951 
952   /// Add the given native pattern to the pattern list. Return a reference to
953   /// `this` for chaining insertions.
954   RewritePatternSet &insert(std::unique_ptr<RewritePattern> pattern) {
955     nativePatterns.emplace_back(std::move(pattern));
956     return *this;
957   }
958 
959   /// Add the given PDL pattern to the pattern list. Return a reference to
960   /// `this` for chaining insertions.
961   RewritePatternSet &insert(PDLPatternModule &&pattern) {
962     pdlPatterns.mergeIn(std::move(pattern));
963     return *this;
964   }
965 
966   // Add a matchAndRewrite style pattern represented as a C function pointer.
967   template <typename OpType>
968   RewritePatternSet &
969   insert(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter)) {
970     struct FnPattern final : public OpRewritePattern<OpType> {
971       FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
972                 MLIRContext *context)
973           : OpRewritePattern<OpType>(context), implFn(implFn) {
974         this->setDebugName(llvm::getTypeName<FnPattern>());
975       }
976 
977       LogicalResult matchAndRewrite(OpType op,
978                                     PatternRewriter &rewriter) const override {
979         return implFn(op, rewriter);
980       }
981 
982     private:
983       LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
984     };
985     add(std::make_unique<FnPattern>(std::move(implFn), getContext()));
986     return *this;
987   }
988 
989 private:
990   /// Add an instance of the pattern type 'T'. Return a reference to `this` for
991   /// chaining insertions.
992   template <typename T, typename... Args>
993   std::enable_if_t<std::is_base_of<RewritePattern, T>::value>
994   addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) {
995     std::unique_ptr<T> pattern =
996         RewritePattern::create<T>(std::forward<Args>(args)...);
997     pattern->addDebugLabels(debugLabels);
998     nativePatterns.emplace_back(std::move(pattern));
999   }
1000 
1001   template <typename T, typename... Args>
1002   std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value>
1003   addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) {
1004     // TODO: Add the provided labels to the PDL pattern when PDL supports
1005     // labels.
1006     pdlPatterns.mergeIn(T(std::forward<Args>(args)...));
1007   }
1008 
1009   MLIRContext *const context;
1010   NativePatternListT nativePatterns;
1011 
1012   // Patterns expressed with PDL. This will compile to a stub class when PDL is
1013   // not enabled.
1014   PDLPatternModule pdlPatterns;
1015 };
1016 
1017 } // namespace mlir
1018 
1019 #endif // MLIR_IR_PATTERNMATCH_H
1020