xref: /llvm-project/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp (revision 6ae7f66ff5169ddc5a7b9ab545707042c77e036c)
//===- FrozenRewritePatternSet.cpp - Frozen Pattern List -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
879d7f618SChris Lattner 
979d7f618SChris Lattner #include "mlir/Rewrite/FrozenRewritePatternSet.h"
1079d7f618SChris Lattner #include "ByteCode.h"
1179d7f618SChris Lattner #include "mlir/Interfaces/SideEffectInterfaces.h"
1279d7f618SChris Lattner #include "mlir/Pass/Pass.h"
1379d7f618SChris Lattner #include "mlir/Pass/PassManager.h"
14a1fe1f5fSKazu Hirata #include <optional>
1579d7f618SChris Lattner 
1679d7f618SChris Lattner using namespace mlir;
1779d7f618SChris Lattner 
18*6ae7f66fSJacques Pienaar // Include the PDL rewrite support.
19*6ae7f66fSJacques Pienaar #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
20*6ae7f66fSJacques Pienaar #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
21*6ae7f66fSJacques Pienaar #include "mlir/Dialect/PDL/IR/PDLOps.h"
22*6ae7f66fSJacques Pienaar 
238c66344eSRiver Riddle static LogicalResult
convertPDLToPDLInterp(ModuleOp pdlModule,DenseMap<Operation *,PDLPatternConfigSet * > & configMap)248c66344eSRiver Riddle convertPDLToPDLInterp(ModuleOp pdlModule,
258c66344eSRiver Riddle                       DenseMap<Operation *, PDLPatternConfigSet *> &configMap) {
2679d7f618SChris Lattner   // Skip the conversion if the module doesn't contain pdl.
27f55ed889SKazu Hirata   if (pdlModule.getOps<pdl::PatternOp>().empty())
2879d7f618SChris Lattner     return success();
2979d7f618SChris Lattner 
3079d7f618SChris Lattner   // Simplify the provided PDL module. Note that we can't use the canonicalizer
3179d7f618SChris Lattner   // here because it would create a cyclic dependency.
3279d7f618SChris Lattner   auto simplifyFn = [](Operation *op) {
3379d7f618SChris Lattner     // TODO: Add folding here if ever necessary.
3479d7f618SChris Lattner     if (isOpTriviallyDead(op))
3579d7f618SChris Lattner       op->erase();
3679d7f618SChris Lattner   };
3779d7f618SChris Lattner   pdlModule.getBody()->walk(simplifyFn);
3879d7f618SChris Lattner 
3979d7f618SChris Lattner   /// Lower the PDL pattern module to the interpreter dialect.
4094a30928Srkayaith   PassManager pdlPipeline(pdlModule->getName());
4179d7f618SChris Lattner #ifdef NDEBUG
4279d7f618SChris Lattner   // We don't want to incur the hit of running the verifier when in release
4379d7f618SChris Lattner   // mode.
4479d7f618SChris Lattner   pdlPipeline.enableVerifier(false);
4579d7f618SChris Lattner #endif
468c66344eSRiver Riddle   pdlPipeline.addPass(createPDLToPDLInterpPass(configMap));
4779d7f618SChris Lattner   if (failed(pdlPipeline.run(pdlModule)))
4879d7f618SChris Lattner     return failure();
4979d7f618SChris Lattner 
5079d7f618SChris Lattner   // Simplify again after running the lowering pipeline.
5179d7f618SChris Lattner   pdlModule.getBody()->walk(simplifyFn);
5279d7f618SChris Lattner   return success();
5379d7f618SChris Lattner }
54*6ae7f66fSJacques Pienaar #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
5579d7f618SChris Lattner 
//===----------------------------------------------------------------------===//
// FrozenRewritePatternSet
//===----------------------------------------------------------------------===//
5979d7f618SChris Lattner 
FrozenRewritePatternSet()6079d7f618SChris Lattner FrozenRewritePatternSet::FrozenRewritePatternSet()
6179d7f618SChris Lattner     : impl(std::make_shared<Impl>()) {}
6279d7f618SChris Lattner 
FrozenRewritePatternSet(RewritePatternSet && patterns,ArrayRef<std::string> disabledPatternLabels,ArrayRef<std::string> enabledPatternLabels)630289a269SRiver Riddle FrozenRewritePatternSet::FrozenRewritePatternSet(
640289a269SRiver Riddle     RewritePatternSet &&patterns, ArrayRef<std::string> disabledPatternLabels,
650289a269SRiver Riddle     ArrayRef<std::string> enabledPatternLabels)
6679d7f618SChris Lattner     : impl(std::make_shared<Impl>()) {
670289a269SRiver Riddle   DenseSet<StringRef> disabledPatterns, enabledPatterns;
680289a269SRiver Riddle   disabledPatterns.insert(disabledPatternLabels.begin(),
690289a269SRiver Riddle                           disabledPatternLabels.end());
700289a269SRiver Riddle   enabledPatterns.insert(enabledPatternLabels.begin(),
710289a269SRiver Riddle                          enabledPatternLabels.end());
720289a269SRiver Riddle 
7376f3c2f3SRiver Riddle   // Functor used to walk all of the operations registered in the context. This
7476f3c2f3SRiver Riddle   // is useful for patterns that get applied to multiple operations, such as
7576f3c2f3SRiver Riddle   // interface and trait based patterns.
76edc6c0ecSRiver Riddle   std::vector<RegisteredOperationName> opInfos;
77edc6c0ecSRiver Riddle   auto addToOpsWhen =
78edc6c0ecSRiver Riddle       [&](std::unique_ptr<RewritePattern> &pattern,
79edc6c0ecSRiver Riddle           function_ref<bool(RegisteredOperationName)> callbackFn) {
80edc6c0ecSRiver Riddle         if (opInfos.empty())
81edc6c0ecSRiver Riddle           opInfos = pattern->getContext()->getRegisteredOperations();
82edc6c0ecSRiver Riddle         for (RegisteredOperationName info : opInfos)
83edc6c0ecSRiver Riddle           if (callbackFn(info))
84edc6c0ecSRiver Riddle             impl->nativeOpSpecificPatternMap[info].push_back(pattern.get());
8576f3c2f3SRiver Riddle         impl->nativeOpSpecificPatternList.push_back(std::move(pattern));
8676f3c2f3SRiver Riddle       };
8776f3c2f3SRiver Riddle 
8876f3c2f3SRiver Riddle   for (std::unique_ptr<RewritePattern> &pat : patterns.getNativePatterns()) {
890289a269SRiver Riddle     // Don't add patterns that haven't been enabled by the user.
900289a269SRiver Riddle     if (!enabledPatterns.empty()) {
910289a269SRiver Riddle       auto isEnabledFn = [&](StringRef label) {
920289a269SRiver Riddle         return enabledPatterns.count(label);
930289a269SRiver Riddle       };
940289a269SRiver Riddle       if (!isEnabledFn(pat->getDebugName()) &&
950289a269SRiver Riddle           llvm::none_of(pat->getDebugLabels(), isEnabledFn))
960289a269SRiver Riddle         continue;
970289a269SRiver Riddle     }
980289a269SRiver Riddle     // Don't add patterns that have been disabled by the user.
990289a269SRiver Riddle     if (!disabledPatterns.empty()) {
1000289a269SRiver Riddle       auto isDisabledFn = [&](StringRef label) {
1010289a269SRiver Riddle         return disabledPatterns.count(label);
1020289a269SRiver Riddle       };
1030289a269SRiver Riddle       if (isDisabledFn(pat->getDebugName()) ||
1040289a269SRiver Riddle           llvm::any_of(pat->getDebugLabels(), isDisabledFn))
1050289a269SRiver Riddle         continue;
1060289a269SRiver Riddle     }
1070289a269SRiver Riddle 
108bef481dfSFangrui Song     if (std::optional<OperationName> rootName = pat->getRootKind()) {
10976f3c2f3SRiver Riddle       impl->nativeOpSpecificPatternMap[*rootName].push_back(pat.get());
11076f3c2f3SRiver Riddle       impl->nativeOpSpecificPatternList.push_back(std::move(pat));
11176f3c2f3SRiver Riddle       continue;
11276f3c2f3SRiver Riddle     }
1130a81ace0SKazu Hirata     if (std::optional<TypeID> interfaceID = pat->getRootInterfaceID()) {
114edc6c0ecSRiver Riddle       addToOpsWhen(pat, [&](RegisteredOperationName info) {
115edc6c0ecSRiver Riddle         return info.hasInterface(*interfaceID);
11676f3c2f3SRiver Riddle       });
11776f3c2f3SRiver Riddle       continue;
11876f3c2f3SRiver Riddle     }
1190a81ace0SKazu Hirata     if (std::optional<TypeID> traitID = pat->getRootTraitID()) {
120edc6c0ecSRiver Riddle       addToOpsWhen(pat, [&](RegisteredOperationName info) {
121edc6c0ecSRiver Riddle         return info.hasTrait(*traitID);
12276f3c2f3SRiver Riddle       });
12376f3c2f3SRiver Riddle       continue;
12476f3c2f3SRiver Riddle     }
12576f3c2f3SRiver Riddle     impl->nativeAnyOpPatterns.push_back(std::move(pat));
12676f3c2f3SRiver Riddle   }
12779d7f618SChris Lattner 
128*6ae7f66fSJacques Pienaar #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
12979d7f618SChris Lattner   // Generate the bytecode for the PDL patterns if any were provided.
13079d7f618SChris Lattner   PDLPatternModule &pdlPatterns = patterns.getPDLPatterns();
13179d7f618SChris Lattner   ModuleOp pdlModule = pdlPatterns.getModule();
13279d7f618SChris Lattner   if (!pdlModule)
13379d7f618SChris Lattner     return;
1348c66344eSRiver Riddle   DenseMap<Operation *, PDLPatternConfigSet *> configMap =
1358c66344eSRiver Riddle       pdlPatterns.takeConfigMap();
1368c66344eSRiver Riddle   if (failed(convertPDLToPDLInterp(pdlModule, configMap)))
13779d7f618SChris Lattner     llvm::report_fatal_error(
13879d7f618SChris Lattner         "failed to lower PDL pattern module to the PDL Interpreter");
13979d7f618SChris Lattner 
14079d7f618SChris Lattner   // Generate the pdl bytecode.
14179d7f618SChris Lattner   impl->pdlByteCode = std::make_unique<detail::PDLByteCode>(
1428c66344eSRiver Riddle       pdlModule, pdlPatterns.takeConfigs(), configMap,
1438c66344eSRiver Riddle       pdlPatterns.takeConstraintFunctions(),
14479d7f618SChris Lattner       pdlPatterns.takeRewriteFunctions());
145*6ae7f66fSJacques Pienaar #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
14679d7f618SChris Lattner }
14779d7f618SChris Lattner 
148e5639b3fSMehdi Amini FrozenRewritePatternSet::~FrozenRewritePatternSet() = default;
149