1 //===- PatternApplicator.h - PatternApplicator ------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements an applicator that applies pattern rewrites based upon a 10 // user defined cost model. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_REWRITE_PATTERNAPPLICATOR_H 15 #define MLIR_REWRITE_PATTERNAPPLICATOR_H 16 17 #include "mlir/Rewrite/FrozenRewritePatternSet.h" 18 19 #include "mlir/IR/Action.h" 20 21 namespace mlir { 22 class PatternRewriter; 23 24 namespace detail { 25 class PDLByteCodeMutableState; 26 } // namespace detail 27 28 /// This is the type of Action that is dispatched when a pattern is applied. 29 /// It captures the pattern to apply on top of the usual context. 30 class ApplyPatternAction : public tracing::ActionImpl<ApplyPatternAction> { 31 public: 32 using Base = tracing::ActionImpl<ApplyPatternAction>; ApplyPatternAction(ArrayRef<IRUnit> irUnits,const Pattern & pattern)33 ApplyPatternAction(ArrayRef<IRUnit> irUnits, const Pattern &pattern) 34 : Base(irUnits), pattern(pattern) {} 35 static constexpr StringLiteral tag = "apply-pattern"; 36 static constexpr StringLiteral desc = 37 "Encapsulate the application of rewrite patterns"; 38 print(raw_ostream & os)39 void print(raw_ostream &os) const override { 40 os << "`" << tag << " pattern: " << pattern.getDebugName(); 41 } 42 43 private: 44 const Pattern &pattern; 45 }; 46 47 /// This class manages the application of a group of rewrite patterns, with a 48 /// user-provided cost model. 49 class PatternApplicator { 50 public: 51 /// The cost model dynamically assigns a PatternBenefit to a particular 52 /// pattern. Users can query contained patterns and pass analysis results to 53 /// applyCostModel. Patterns to be discarded should have a benefit of 54 /// `impossibleToMatch`. 55 using CostModel = function_ref<PatternBenefit(const Pattern &)>; 56 57 explicit PatternApplicator(const FrozenRewritePatternSet &frozenPatternList); 58 ~PatternApplicator(); 59 60 /// Attempt to match and rewrite the given op with any pattern, allowing a 61 /// predicate to decide if a pattern can be applied or not, and hooks for if 62 /// the pattern match was a success or failure. 63 /// 64 /// canApply: called before each match and rewrite attempt; return false to 65 /// skip pattern. 66 /// onFailure: called when a pattern fails to match to perform cleanup. 67 /// onSuccess: called when a pattern match succeeds; return failure() to 68 /// invalidate the match and try another pattern. 69 LogicalResult 70 matchAndRewrite(Operation *op, PatternRewriter &rewriter, 71 function_ref<bool(const Pattern &)> canApply = {}, 72 function_ref<void(const Pattern &)> onFailure = {}, 73 function_ref<LogicalResult(const Pattern &)> onSuccess = {}); 74 75 /// Apply a cost model to the patterns within this applicator. 76 void applyCostModel(CostModel model); 77 78 /// Apply the default cost model that solely uses the pattern's static 79 /// benefit. applyDefaultCostModel()80 void applyDefaultCostModel() { 81 applyCostModel([](const Pattern &pattern) { return pattern.getBenefit(); }); 82 } 83 84 /// Walk all of the patterns within the applicator. 85 void walkAllPatterns(function_ref<void(const Pattern &)> walk); 86 87 private: 88 /// The list that owns the patterns used within this applicator. 89 const FrozenRewritePatternSet &frozenPatternList; 90 /// The set of patterns to match for each operation, stable sorted by benefit. 91 DenseMap<OperationName, SmallVector<const RewritePattern *, 2>> patterns; 92 /// The set of patterns that may match against any operation type, stable 93 /// sorted by benefit. 94 SmallVector<const RewritePattern *, 1> anyOpPatterns; 95 /// The mutable state used during execution of the PDL bytecode. 96 std::unique_ptr<detail::PDLByteCodeMutableState> mutableByteCodeState; 97 }; 98 99 } // namespace mlir 100 101 #endif // MLIR_REWRITE_PATTERNAPPLICATOR_H 102