xref: /llvm-project/mlir/include/mlir/Rewrite/PatternApplicator.h (revision 1020150e7a6f6d6f833c232125c5ab817c03c76b)
1 //===- PatternApplicator.h - PatternApplicator ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements an applicator that applies pattern rewrites based upon a
10 // user-defined cost model.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_REWRITE_PATTERNAPPLICATOR_H
15 #define MLIR_REWRITE_PATTERNAPPLICATOR_H
16 
17 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
18 
19 #include "mlir/IR/Action.h"
20 
21 namespace mlir {
22 class PatternRewriter;
23 
24 namespace detail {
25 class PDLByteCodeMutableState;
26 } // namespace detail
27 
28 /// This is the type of Action that is dispatched when a pattern is applied.
29 /// It captures the pattern to apply on top of the usual context.
30 class ApplyPatternAction : public tracing::ActionImpl<ApplyPatternAction> {
31 public:
32   using Base = tracing::ActionImpl<ApplyPatternAction>;
ApplyPatternAction(ArrayRef<IRUnit> irUnits,const Pattern & pattern)33   ApplyPatternAction(ArrayRef<IRUnit> irUnits, const Pattern &pattern)
34       : Base(irUnits), pattern(pattern) {}
35   static constexpr StringLiteral tag = "apply-pattern";
36   static constexpr StringLiteral desc =
37       "Encapsulate the application of rewrite patterns";
38 
print(raw_ostream & os)39   void print(raw_ostream &os) const override {
40     os << "`" << tag << " pattern: " << pattern.getDebugName();
41   }
42 
43 private:
44   const Pattern &pattern;
45 };
46 
47 /// This class manages the application of a group of rewrite patterns, with a
48 /// user-provided cost model.
49 class PatternApplicator {
50 public:
51   /// The cost model dynamically assigns a PatternBenefit to a particular
52   /// pattern. Users can query contained patterns and pass analysis results to
53   /// applyCostModel. Patterns to be discarded should have a benefit of
54   /// `impossibleToMatch`.
55   using CostModel = function_ref<PatternBenefit(const Pattern &)>;
56 
57   explicit PatternApplicator(const FrozenRewritePatternSet &frozenPatternList);
58   ~PatternApplicator();
59 
60   /// Attempt to match and rewrite the given op with any pattern, allowing a
61   /// predicate to decide if a pattern can be applied or not, and hooks for if
62   /// the pattern match was a success or failure.
63   ///
64   /// canApply:  called before each match and rewrite attempt; return false to
65   ///            skip pattern.
66   /// onFailure: called when a pattern fails to match to perform cleanup.
67   /// onSuccess: called when a pattern match succeeds; return failure() to
68   ///            invalidate the match and try another pattern.
69   LogicalResult
70   matchAndRewrite(Operation *op, PatternRewriter &rewriter,
71                   function_ref<bool(const Pattern &)> canApply = {},
72                   function_ref<void(const Pattern &)> onFailure = {},
73                   function_ref<LogicalResult(const Pattern &)> onSuccess = {});
74 
75   /// Apply a cost model to the patterns within this applicator.
76   void applyCostModel(CostModel model);
77 
78   /// Apply the default cost model that solely uses the pattern's static
79   /// benefit.
applyDefaultCostModel()80   void applyDefaultCostModel() {
81     applyCostModel([](const Pattern &pattern) { return pattern.getBenefit(); });
82   }
83 
84   /// Walk all of the patterns within the applicator.
85   void walkAllPatterns(function_ref<void(const Pattern &)> walk);
86 
87 private:
88   /// The list that owns the patterns used within this applicator.
89   const FrozenRewritePatternSet &frozenPatternList;
90   /// The set of patterns to match for each operation, stable sorted by benefit.
91   DenseMap<OperationName, SmallVector<const RewritePattern *, 2>> patterns;
92   /// The set of patterns that may match against any operation type, stable
93   /// sorted by benefit.
94   SmallVector<const RewritePattern *, 1> anyOpPatterns;
95   /// The mutable state used during execution of the PDL bytecode.
96   std::unique_ptr<detail::PDLByteCodeMutableState> mutableByteCodeState;
97 };
98 
99 } // namespace mlir
100 
101 #endif // MLIR_REWRITE_PATTERNAPPLICATOR_H
102