xref: /llvm-project/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp (revision 6ae7f66ff5169ddc5a7b9ab545707042c77e036c)
1 //===- FrozenRewritePatternSet.cpp - Frozen Pattern List -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
10 #include "ByteCode.h"
11 #include "mlir/Interfaces/SideEffectInterfaces.h"
12 #include "mlir/Pass/Pass.h"
13 #include "mlir/Pass/PassManager.h"
14 #include <optional>
15 
16 using namespace mlir;
17 
18 // Include the PDL rewrite support.
19 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
20 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
21 #include "mlir/Dialect/PDL/IR/PDLOps.h"
22 
23 static LogicalResult
convertPDLToPDLInterp(ModuleOp pdlModule,DenseMap<Operation *,PDLPatternConfigSet * > & configMap)24 convertPDLToPDLInterp(ModuleOp pdlModule,
25                       DenseMap<Operation *, PDLPatternConfigSet *> &configMap) {
26   // Skip the conversion if the module doesn't contain pdl.
27   if (pdlModule.getOps<pdl::PatternOp>().empty())
28     return success();
29 
30   // Simplify the provided PDL module. Note that we can't use the canonicalizer
31   // here because it would create a cyclic dependency.
32   auto simplifyFn = [](Operation *op) {
33     // TODO: Add folding here if ever necessary.
34     if (isOpTriviallyDead(op))
35       op->erase();
36   };
37   pdlModule.getBody()->walk(simplifyFn);
38 
39   /// Lower the PDL pattern module to the interpreter dialect.
40   PassManager pdlPipeline(pdlModule->getName());
41 #ifdef NDEBUG
42   // We don't want to incur the hit of running the verifier when in release
43   // mode.
44   pdlPipeline.enableVerifier(false);
45 #endif
46   pdlPipeline.addPass(createPDLToPDLInterpPass(configMap));
47   if (failed(pdlPipeline.run(pdlModule)))
48     return failure();
49 
50   // Simplify again after running the lowering pipeline.
51   pdlModule.getBody()->walk(simplifyFn);
52   return success();
53 }
54 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
55 
56 //===----------------------------------------------------------------------===//
57 // FrozenRewritePatternSet
58 //===----------------------------------------------------------------------===//
59 
FrozenRewritePatternSet()60 FrozenRewritePatternSet::FrozenRewritePatternSet()
61     : impl(std::make_shared<Impl>()) {}
62 
FrozenRewritePatternSet(RewritePatternSet && patterns,ArrayRef<std::string> disabledPatternLabels,ArrayRef<std::string> enabledPatternLabels)63 FrozenRewritePatternSet::FrozenRewritePatternSet(
64     RewritePatternSet &&patterns, ArrayRef<std::string> disabledPatternLabels,
65     ArrayRef<std::string> enabledPatternLabels)
66     : impl(std::make_shared<Impl>()) {
67   DenseSet<StringRef> disabledPatterns, enabledPatterns;
68   disabledPatterns.insert(disabledPatternLabels.begin(),
69                           disabledPatternLabels.end());
70   enabledPatterns.insert(enabledPatternLabels.begin(),
71                          enabledPatternLabels.end());
72 
73   // Functor used to walk all of the operations registered in the context. This
74   // is useful for patterns that get applied to multiple operations, such as
75   // interface and trait based patterns.
76   std::vector<RegisteredOperationName> opInfos;
77   auto addToOpsWhen =
78       [&](std::unique_ptr<RewritePattern> &pattern,
79           function_ref<bool(RegisteredOperationName)> callbackFn) {
80         if (opInfos.empty())
81           opInfos = pattern->getContext()->getRegisteredOperations();
82         for (RegisteredOperationName info : opInfos)
83           if (callbackFn(info))
84             impl->nativeOpSpecificPatternMap[info].push_back(pattern.get());
85         impl->nativeOpSpecificPatternList.push_back(std::move(pattern));
86       };
87 
88   for (std::unique_ptr<RewritePattern> &pat : patterns.getNativePatterns()) {
89     // Don't add patterns that haven't been enabled by the user.
90     if (!enabledPatterns.empty()) {
91       auto isEnabledFn = [&](StringRef label) {
92         return enabledPatterns.count(label);
93       };
94       if (!isEnabledFn(pat->getDebugName()) &&
95           llvm::none_of(pat->getDebugLabels(), isEnabledFn))
96         continue;
97     }
98     // Don't add patterns that have been disabled by the user.
99     if (!disabledPatterns.empty()) {
100       auto isDisabledFn = [&](StringRef label) {
101         return disabledPatterns.count(label);
102       };
103       if (isDisabledFn(pat->getDebugName()) ||
104           llvm::any_of(pat->getDebugLabels(), isDisabledFn))
105         continue;
106     }
107 
108     if (std::optional<OperationName> rootName = pat->getRootKind()) {
109       impl->nativeOpSpecificPatternMap[*rootName].push_back(pat.get());
110       impl->nativeOpSpecificPatternList.push_back(std::move(pat));
111       continue;
112     }
113     if (std::optional<TypeID> interfaceID = pat->getRootInterfaceID()) {
114       addToOpsWhen(pat, [&](RegisteredOperationName info) {
115         return info.hasInterface(*interfaceID);
116       });
117       continue;
118     }
119     if (std::optional<TypeID> traitID = pat->getRootTraitID()) {
120       addToOpsWhen(pat, [&](RegisteredOperationName info) {
121         return info.hasTrait(*traitID);
122       });
123       continue;
124     }
125     impl->nativeAnyOpPatterns.push_back(std::move(pat));
126   }
127 
128 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
129   // Generate the bytecode for the PDL patterns if any were provided.
130   PDLPatternModule &pdlPatterns = patterns.getPDLPatterns();
131   ModuleOp pdlModule = pdlPatterns.getModule();
132   if (!pdlModule)
133     return;
134   DenseMap<Operation *, PDLPatternConfigSet *> configMap =
135       pdlPatterns.takeConfigMap();
136   if (failed(convertPDLToPDLInterp(pdlModule, configMap)))
137     llvm::report_fatal_error(
138         "failed to lower PDL pattern module to the PDL Interpreter");
139 
140   // Generate the pdl bytecode.
141   impl->pdlByteCode = std::make_unique<detail::PDLByteCode>(
142       pdlModule, pdlPatterns.takeConfigs(), configMap,
143       pdlPatterns.takeConstraintFunctions(),
144       pdlPatterns.takeRewriteFunctions());
145 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
146 }
147 
148 FrozenRewritePatternSet::~FrozenRewritePatternSet() = default;
149