1 //===- FrozenRewritePatternSet.h --------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_REWRITE_FROZENREWRITEPATTERNSET_H 10 #define MLIR_REWRITE_FROZENREWRITEPATTERNSET_H 11 12 #include "mlir/IR/PatternMatch.h" 13 14 namespace mlir { 15 namespace detail { 16 class PDLByteCode; 17 } // namespace detail 18 19 /// This class represents a frozen set of patterns that can be processed by a 20 /// pattern applicator. This class is designed to enable caching pattern lists 21 /// such that they need not be continuously recomputed. Note that all copies of 22 /// this class share the same compiled pattern list, allowing for a reduction in 23 /// the number of duplicated patterns that need to be created. 24 class FrozenRewritePatternSet { 25 using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>; 26 27 public: 28 /// A map of operation specific native patterns. 29 using OpSpecificNativePatternListT = 30 DenseMap<OperationName, std::vector<RewritePattern *>>; 31 32 FrozenRewritePatternSet(); 33 FrozenRewritePatternSet(FrozenRewritePatternSet &&patterns) = default; 34 FrozenRewritePatternSet(const FrozenRewritePatternSet &patterns) = default; 35 FrozenRewritePatternSet & 36 operator=(const FrozenRewritePatternSet &patterns) = default; 37 FrozenRewritePatternSet & 38 operator=(FrozenRewritePatternSet &&patterns) = default; 39 ~FrozenRewritePatternSet(); 40 41 /// Freeze the patterns held in `patterns`, and take ownership. 42 /// `disabledPatternLabels` is a set of labels used to filter out input 43 /// patterns with a debug label or debug name in this set. 44 /// `enabledPatternLabels` is a set of labels used to filter out input 45 /// patterns that do not have one of the labels in this set. Debug labels must 46 /// be set explicitly on patterns or when adding them with 47 /// `RewritePatternSet::addWithLabel`. Debug names may be empty, but patterns 48 /// created with `RewritePattern::create` have their default debug name set to 49 /// their type name. 50 FrozenRewritePatternSet( 51 RewritePatternSet &&patterns, 52 ArrayRef<std::string> disabledPatternLabels = std::nullopt, 53 ArrayRef<std::string> enabledPatternLabels = std::nullopt); 54 55 /// Return the op specific native patterns held by this list. getOpSpecificNativePatterns()56 const OpSpecificNativePatternListT &getOpSpecificNativePatterns() const { 57 return impl->nativeOpSpecificPatternMap; 58 } 59 60 /// Return the "match any" native patterns held by this list. 61 iterator_range<llvm::pointee_iterator<NativePatternListT::const_iterator>> getMatchAnyOpNativePatterns()62 getMatchAnyOpNativePatterns() const { 63 const NativePatternListT &nativeList = impl->nativeAnyOpPatterns; 64 return llvm::make_pointee_range(nativeList); 65 } 66 67 /// Return the compiled PDL bytecode held by this list. Returns null if 68 /// there are no PDL patterns within the list. getPDLByteCode()69 const detail::PDLByteCode *getPDLByteCode() const { 70 return impl->pdlByteCode.get(); 71 } 72 73 private: 74 /// The internal implementation of the frozen pattern list. 75 struct Impl { 76 /// The set of native C++ rewrite patterns that are matched to specific 77 /// operation kinds. 78 OpSpecificNativePatternListT nativeOpSpecificPatternMap; 79 80 /// The full op-specific native rewrite list. This allows for the map above 81 /// to contain duplicate patterns, e.g. for interfaces and traits. 82 NativePatternListT nativeOpSpecificPatternList; 83 84 /// The set of native C++ rewrite patterns that are matched to "any" 85 /// operation. 86 NativePatternListT nativeAnyOpPatterns; 87 88 /// The bytecode containing the compiled PDL patterns. 89 std::unique_ptr<detail::PDLByteCode> pdlByteCode; 90 }; 91 92 /// A pointer to the internal pattern list. This uses a shared_ptr to avoid 93 /// the need to compile the same pattern list multiple times. For example, 94 /// during multi-threaded pass execution, all copies of a pass can share the 95 /// same pattern list. 96 std::shared_ptr<Impl> impl; 97 }; 98 99 } // namespace mlir 100 101 #endif // MLIR_REWRITE_FROZENREWRITEPATTERNSET_H 102