xref: /llvm-project/mlir/include/mlir/Rewrite/FrozenRewritePatternSet.h (revision 1a36588ec64ae8576e531e6f0b49eadb90ab0b11)
//===- FrozenRewritePatternSet.h --------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_REWRITE_FROZENREWRITEPATTERNSET_H
10 #define MLIR_REWRITE_FROZENREWRITEPATTERNSET_H
11 
12 #include "mlir/IR/PatternMatch.h"
13 
14 namespace mlir {
namespace detail {
/// Forward declaration of the compiled PDL bytecode container. This header
/// only names the type through pointers, so the full definition is not needed
/// here.
class PDLByteCode;
} // namespace detail
18 
19 /// This class represents a frozen set of patterns that can be processed by a
20 /// pattern applicator. This class is designed to enable caching pattern lists
21 /// such that they need not be continuously recomputed. Note that all copies of
22 /// this class share the same compiled pattern list, allowing for a reduction in
23 /// the number of duplicated patterns that need to be created.
24 class FrozenRewritePatternSet {
25   using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
26 
27 public:
28   /// A map of operation specific native patterns.
29   using OpSpecificNativePatternListT =
30       DenseMap<OperationName, std::vector<RewritePattern *>>;
31 
32   FrozenRewritePatternSet();
33   FrozenRewritePatternSet(FrozenRewritePatternSet &&patterns) = default;
34   FrozenRewritePatternSet(const FrozenRewritePatternSet &patterns) = default;
35   FrozenRewritePatternSet &
36   operator=(const FrozenRewritePatternSet &patterns) = default;
37   FrozenRewritePatternSet &
38   operator=(FrozenRewritePatternSet &&patterns) = default;
39   ~FrozenRewritePatternSet();
40 
41   /// Freeze the patterns held in `patterns`, and take ownership.
42   /// `disabledPatternLabels` is a set of labels used to filter out input
43   /// patterns with a debug label or debug name in this set.
44   /// `enabledPatternLabels` is a set of labels used to filter out input
45   /// patterns that do not have one of the labels in this set. Debug labels must
46   /// be set explicitly on patterns or when adding them with
47   /// `RewritePatternSet::addWithLabel`. Debug names may be empty, but patterns
48   /// created with `RewritePattern::create` have their default debug name set to
49   /// their type name.
50   FrozenRewritePatternSet(
51       RewritePatternSet &&patterns,
52       ArrayRef<std::string> disabledPatternLabels = std::nullopt,
53       ArrayRef<std::string> enabledPatternLabels = std::nullopt);
54 
55   /// Return the op specific native patterns held by this list.
getOpSpecificNativePatterns()56   const OpSpecificNativePatternListT &getOpSpecificNativePatterns() const {
57     return impl->nativeOpSpecificPatternMap;
58   }
59 
60   /// Return the "match any" native patterns held by this list.
61   iterator_range<llvm::pointee_iterator<NativePatternListT::const_iterator>>
getMatchAnyOpNativePatterns()62   getMatchAnyOpNativePatterns() const {
63     const NativePatternListT &nativeList = impl->nativeAnyOpPatterns;
64     return llvm::make_pointee_range(nativeList);
65   }
66 
67   /// Return the compiled PDL bytecode held by this list. Returns null if
68   /// there are no PDL patterns within the list.
getPDLByteCode()69   const detail::PDLByteCode *getPDLByteCode() const {
70     return impl->pdlByteCode.get();
71   }
72 
73 private:
74   /// The internal implementation of the frozen pattern list.
75   struct Impl {
76     /// The set of native C++ rewrite patterns that are matched to specific
77     /// operation kinds.
78     OpSpecificNativePatternListT nativeOpSpecificPatternMap;
79 
80     /// The full op-specific native rewrite list. This allows for the map above
81     /// to contain duplicate patterns, e.g. for interfaces and traits.
82     NativePatternListT nativeOpSpecificPatternList;
83 
84     /// The set of native C++ rewrite patterns that are matched to "any"
85     /// operation.
86     NativePatternListT nativeAnyOpPatterns;
87 
88     /// The bytecode containing the compiled PDL patterns.
89     std::unique_ptr<detail::PDLByteCode> pdlByteCode;
90   };
91 
92   /// A pointer to the internal pattern list. This uses a shared_ptr to avoid
93   /// the need to compile the same pattern list multiple times. For example,
94   /// during multi-threaded pass execution, all copies of a pass can share the
95   /// same pattern list.
96   std::shared_ptr<Impl> impl;
97 };
98 
99 } // namespace mlir
100 
101 #endif // MLIR_REWRITE_FROZENREWRITEPATTERNSET_H
102