xref: /llvm-project/mlir/lib/Rewrite/PatternApplicator.cpp (revision 78e0cca135076154abab21eadd146dc1dfd3549f)
//===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements an applicator that applies pattern rewrites based upon a
// user-defined cost model.
//
//===----------------------------------------------------------------------===//
13b6eb26fdSRiver Riddle 
14b6eb26fdSRiver Riddle #include "mlir/Rewrite/PatternApplicator.h"
15abfd1a8bSRiver Riddle #include "ByteCode.h"
16b6eb26fdSRiver Riddle #include "llvm/Support/Debug.h"
17b6eb26fdSRiver Riddle 
186176a8f9SFrederik Gossen #define DEBUG_TYPE "pattern-application"
1976f3c2f3SRiver Riddle 
20b6eb26fdSRiver Riddle using namespace mlir;
21abfd1a8bSRiver Riddle using namespace mlir::detail;
22abfd1a8bSRiver Riddle 
PatternApplicator(const FrozenRewritePatternSet & frozenPatternList)23abfd1a8bSRiver Riddle PatternApplicator::PatternApplicator(
2479d7f618SChris Lattner     const FrozenRewritePatternSet &frozenPatternList)
25abfd1a8bSRiver Riddle     : frozenPatternList(frozenPatternList) {
26abfd1a8bSRiver Riddle   if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
27abfd1a8bSRiver Riddle     mutableByteCodeState = std::make_unique<PDLByteCodeMutableState>();
28abfd1a8bSRiver Riddle     bytecode->initializeMutableState(*mutableByteCodeState);
29abfd1a8bSRiver Riddle   }
30abfd1a8bSRiver Riddle }
31e5639b3fSMehdi Amini PatternApplicator::~PatternApplicator() = default;
32b6eb26fdSRiver Riddle 
336176a8f9SFrederik Gossen #ifndef NDEBUG
3476f3c2f3SRiver Riddle /// Log a message for a pattern that is impossible to match.
logImpossibleToMatch(const Pattern & pattern)3576f3c2f3SRiver Riddle static void logImpossibleToMatch(const Pattern &pattern) {
3676f3c2f3SRiver Riddle   llvm::dbgs() << "Ignoring pattern '" << pattern.getRootKind()
3776f3c2f3SRiver Riddle                << "' because it is impossible to match or cannot lead "
3876f3c2f3SRiver Riddle                   "to legal IR (by cost model)\n";
3976f3c2f3SRiver Riddle }
40b6eb26fdSRiver Riddle 
416176a8f9SFrederik Gossen /// Log IR after pattern application.
getDumpRootOp(Operation * op)426176a8f9SFrederik Gossen static Operation *getDumpRootOp(Operation *op) {
43*78e0cca1SRobert Konicar   Operation *isolatedParent =
44*78e0cca1SRobert Konicar       op->getParentWithTrait<mlir::OpTrait::IsIsolatedFromAbove>();
45*78e0cca1SRobert Konicar   if (isolatedParent)
46*78e0cca1SRobert Konicar     return isolatedParent;
47*78e0cca1SRobert Konicar   return op;
486176a8f9SFrederik Gossen }
logSucessfulPatternApplication(Operation * op)496176a8f9SFrederik Gossen static void logSucessfulPatternApplication(Operation *op) {
506176a8f9SFrederik Gossen   llvm::dbgs() << "// *** IR Dump After Pattern Application ***\n";
516176a8f9SFrederik Gossen   op->dump();
526176a8f9SFrederik Gossen   llvm::dbgs() << "\n\n";
536176a8f9SFrederik Gossen }
546176a8f9SFrederik Gossen #endif
556176a8f9SFrederik Gossen 
applyCostModel(CostModel model)56b6eb26fdSRiver Riddle void PatternApplicator::applyCostModel(CostModel model) {
57abfd1a8bSRiver Riddle   // Apply the cost model to the bytecode patterns first, and then the native
58abfd1a8bSRiver Riddle   // patterns.
59abfd1a8bSRiver Riddle   if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
60e4853be2SMehdi Amini     for (const auto &it : llvm::enumerate(bytecode->getPatterns()))
61abfd1a8bSRiver Riddle       mutableByteCodeState->updatePatternBenefit(it.index(), model(it.value()));
62abfd1a8bSRiver Riddle   }
63abfd1a8bSRiver Riddle 
6476f3c2f3SRiver Riddle   // Copy over the patterns so that we can sort by benefit based on the cost
6576f3c2f3SRiver Riddle   // model. Patterns that are already impossible to match are ignored.
66b6eb26fdSRiver Riddle   patterns.clear();
6776f3c2f3SRiver Riddle   for (const auto &it : frozenPatternList.getOpSpecificNativePatterns()) {
6876f3c2f3SRiver Riddle     for (const RewritePattern *pattern : it.second) {
6976f3c2f3SRiver Riddle       if (pattern->getBenefit().isImpossibleToMatch())
706176a8f9SFrederik Gossen         LLVM_DEBUG(logImpossibleToMatch(*pattern));
71b6eb26fdSRiver Riddle       else
7276f3c2f3SRiver Riddle         patterns[it.first].push_back(pattern);
7376f3c2f3SRiver Riddle     }
7476f3c2f3SRiver Riddle   }
7576f3c2f3SRiver Riddle   anyOpPatterns.clear();
7676f3c2f3SRiver Riddle   for (const RewritePattern &pattern :
7776f3c2f3SRiver Riddle        frozenPatternList.getMatchAnyOpNativePatterns()) {
7876f3c2f3SRiver Riddle     if (pattern.getBenefit().isImpossibleToMatch())
796176a8f9SFrederik Gossen       LLVM_DEBUG(logImpossibleToMatch(pattern));
8076f3c2f3SRiver Riddle     else
8176f3c2f3SRiver Riddle       anyOpPatterns.push_back(&pattern);
82b6eb26fdSRiver Riddle   }
83b6eb26fdSRiver Riddle 
84b6eb26fdSRiver Riddle   // Sort the patterns using the provided cost model.
853fffffa8SRiver Riddle   llvm::SmallDenseMap<const Pattern *, PatternBenefit> benefits;
863fffffa8SRiver Riddle   auto cmp = [&benefits](const Pattern *lhs, const Pattern *rhs) {
87b6eb26fdSRiver Riddle     return benefits[lhs] > benefits[rhs];
88b6eb26fdSRiver Riddle   };
893fffffa8SRiver Riddle   auto processPatternList = [&](SmallVectorImpl<const RewritePattern *> &list) {
90b6eb26fdSRiver Riddle     // Special case for one pattern in the list, which is the most common case.
91b6eb26fdSRiver Riddle     if (list.size() == 1) {
92b6eb26fdSRiver Riddle       if (model(*list.front()).isImpossibleToMatch()) {
936176a8f9SFrederik Gossen         LLVM_DEBUG(logImpossibleToMatch(*list.front()));
94b6eb26fdSRiver Riddle         list.clear();
95b6eb26fdSRiver Riddle       }
96b6eb26fdSRiver Riddle       return;
97b6eb26fdSRiver Riddle     }
98b6eb26fdSRiver Riddle 
99b6eb26fdSRiver Riddle     // Collect the dynamic benefits for the current pattern list.
100b6eb26fdSRiver Riddle     benefits.clear();
1013fffffa8SRiver Riddle     for (const Pattern *pat : list)
102b6eb26fdSRiver Riddle       benefits.try_emplace(pat, model(*pat));
103b6eb26fdSRiver Riddle 
104b6eb26fdSRiver Riddle     // Sort patterns with highest benefit first, and remove those that are
105b6eb26fdSRiver Riddle     // impossible to match.
106b6eb26fdSRiver Riddle     std::stable_sort(list.begin(), list.end(), cmp);
1076176a8f9SFrederik Gossen     while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
1086176a8f9SFrederik Gossen       LLVM_DEBUG(logImpossibleToMatch(*list.back()));
1096176a8f9SFrederik Gossen       list.pop_back();
1106176a8f9SFrederik Gossen     }
111b6eb26fdSRiver Riddle   };
112b6eb26fdSRiver Riddle   for (auto &it : patterns)
113b6eb26fdSRiver Riddle     processPatternList(it.second);
114b6eb26fdSRiver Riddle   processPatternList(anyOpPatterns);
115b6eb26fdSRiver Riddle }
116b6eb26fdSRiver Riddle 
walkAllPatterns(function_ref<void (const Pattern &)> walk)117b6eb26fdSRiver Riddle void PatternApplicator::walkAllPatterns(
118b6eb26fdSRiver Riddle     function_ref<void(const Pattern &)> walk) {
11976f3c2f3SRiver Riddle   for (const auto &it : frozenPatternList.getOpSpecificNativePatterns())
12076f3c2f3SRiver Riddle     for (const auto &pattern : it.second)
12176f3c2f3SRiver Riddle       walk(*pattern);
12276f3c2f3SRiver Riddle   for (const Pattern &it : frozenPatternList.getMatchAnyOpNativePatterns())
1233fffffa8SRiver Riddle     walk(it);
124abfd1a8bSRiver Riddle   if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
125abfd1a8bSRiver Riddle     for (const Pattern &it : bytecode->getPatterns())
126abfd1a8bSRiver Riddle       walk(it);
127abfd1a8bSRiver Riddle   }
128b6eb26fdSRiver Riddle }
129b6eb26fdSRiver Riddle 
matchAndRewrite(Operation * op,PatternRewriter & rewriter,function_ref<bool (const Pattern &)> canApply,function_ref<void (const Pattern &)> onFailure,function_ref<LogicalResult (const Pattern &)> onSuccess)130b6eb26fdSRiver Riddle LogicalResult PatternApplicator::matchAndRewrite(
131b6eb26fdSRiver Riddle     Operation *op, PatternRewriter &rewriter,
132b6eb26fdSRiver Riddle     function_ref<bool(const Pattern &)> canApply,
133b6eb26fdSRiver Riddle     function_ref<void(const Pattern &)> onFailure,
134b6eb26fdSRiver Riddle     function_ref<LogicalResult(const Pattern &)> onSuccess) {
135abfd1a8bSRiver Riddle   // Before checking native patterns, first match against the bytecode. This
136abfd1a8bSRiver Riddle   // won't automatically perform any rewrites so there is no need to worry about
137abfd1a8bSRiver Riddle   // conflicts.
138abfd1a8bSRiver Riddle   SmallVector<PDLByteCode::MatchResult, 4> pdlMatches;
139abfd1a8bSRiver Riddle   const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode();
140abfd1a8bSRiver Riddle   if (bytecode)
141abfd1a8bSRiver Riddle     bytecode->match(op, rewriter, pdlMatches, *mutableByteCodeState);
142abfd1a8bSRiver Riddle 
143b6eb26fdSRiver Riddle   // Check to see if there are patterns matching this specific operation type.
1443fffffa8SRiver Riddle   MutableArrayRef<const RewritePattern *> opPatterns;
145b6eb26fdSRiver Riddle   auto patternIt = patterns.find(op->getName());
146b6eb26fdSRiver Riddle   if (patternIt != patterns.end())
147b6eb26fdSRiver Riddle     opPatterns = patternIt->second;
148b6eb26fdSRiver Riddle 
149b6eb26fdSRiver Riddle   // Process the patterns for that match the specific operation type, and any
150b6eb26fdSRiver Riddle   // operation type in an interleaved fashion.
15185ab413bSRiver Riddle   unsigned opIt = 0, opE = opPatterns.size();
15285ab413bSRiver Riddle   unsigned anyIt = 0, anyE = anyOpPatterns.size();
15385ab413bSRiver Riddle   unsigned pdlIt = 0, pdlE = pdlMatches.size();
15485ab413bSRiver Riddle   LogicalResult result = failure();
15585ab413bSRiver Riddle   do {
156abfd1a8bSRiver Riddle     // Find the next pattern with the highest benefit.
157abfd1a8bSRiver Riddle     const Pattern *bestPattern = nullptr;
15885ab413bSRiver Riddle     unsigned *bestPatternIt = &opIt;
15985ab413bSRiver Riddle 
160abfd1a8bSRiver Riddle     /// Operation specific patterns.
16185ab413bSRiver Riddle     if (opIt < opE)
16285ab413bSRiver Riddle       bestPattern = opPatterns[opIt];
163abfd1a8bSRiver Riddle     /// Operation agnostic patterns.
16485ab413bSRiver Riddle     if (anyIt < anyE &&
16585ab413bSRiver Riddle         (!bestPattern ||
16685ab413bSRiver Riddle          bestPattern->getBenefit() < anyOpPatterns[anyIt]->getBenefit())) {
16785ab413bSRiver Riddle       bestPatternIt = &anyIt;
16885ab413bSRiver Riddle       bestPattern = anyOpPatterns[anyIt];
16985ab413bSRiver Riddle     }
1706ae7f66fSJacques Pienaar 
1716ae7f66fSJacques Pienaar     const PDLByteCode::MatchResult *pdlMatch = nullptr;
172abfd1a8bSRiver Riddle     /// PDL patterns.
17385ab413bSRiver Riddle     if (pdlIt < pdlE && (!bestPattern || bestPattern->getBenefit() <
17485ab413bSRiver Riddle                                              pdlMatches[pdlIt].benefit)) {
17585ab413bSRiver Riddle       bestPatternIt = &pdlIt;
17685ab413bSRiver Riddle       pdlMatch = &pdlMatches[pdlIt];
17785ab413bSRiver Riddle       bestPattern = pdlMatch->pattern;
178abfd1a8bSRiver Riddle     }
1796ae7f66fSJacques Pienaar 
180abfd1a8bSRiver Riddle     if (!bestPattern)
181abfd1a8bSRiver Riddle       break;
182b6eb26fdSRiver Riddle 
18385ab413bSRiver Riddle     // Update the pattern iterator on failure so that this pattern isn't
18485ab413bSRiver Riddle     // attempted again.
18585ab413bSRiver Riddle     ++(*bestPatternIt);
18685ab413bSRiver Riddle 
187b6eb26fdSRiver Riddle     // Check that the pattern can be applied.
188abfd1a8bSRiver Riddle     if (canApply && !canApply(*bestPattern))
189abfd1a8bSRiver Riddle       continue;
190b6eb26fdSRiver Riddle 
191b6eb26fdSRiver Riddle     // Try to match and rewrite this pattern. The patterns are sorted by
192abfd1a8bSRiver Riddle     // benefit, so if we match we can immediately rewrite. For PDL patterns, the
193abfd1a8bSRiver Riddle     // match has already been performed, we just need to rewrite.
194e24b91b0SMehdi Amini     bool matched = false;
195e24b91b0SMehdi Amini     op->getContext()->executeAction<ApplyPatternAction>(
196e24b91b0SMehdi Amini         [&]() {
197b6eb26fdSRiver Riddle           rewriter.setInsertionPoint(op);
1986176a8f9SFrederik Gossen #ifndef NDEBUG
199e24b91b0SMehdi Amini           // Operation `op` may be invalidated after applying the rewrite
200e24b91b0SMehdi Amini           // pattern.
2016176a8f9SFrederik Gossen           Operation *dumpRootOp = getDumpRootOp(op);
2026176a8f9SFrederik Gossen #endif
203abfd1a8bSRiver Riddle           if (pdlMatch) {
204e24b91b0SMehdi Amini             result =
205e24b91b0SMehdi Amini                 bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
206abfd1a8bSRiver Riddle           } else {
2078c66344eSRiver Riddle             LLVM_DEBUG(llvm::dbgs() << "Trying to match \""
2088c66344eSRiver Riddle                                     << bestPattern->getDebugName() << "\"\n");
2098c66344eSRiver Riddle 
210e24b91b0SMehdi Amini             const auto *pattern =
211e24b91b0SMehdi Amini                 static_cast<const RewritePattern *>(bestPattern);
21285ab413bSRiver Riddle             result = pattern->matchAndRewrite(op, rewriter);
213e2a77644SButygin 
214e24b91b0SMehdi Amini             LLVM_DEBUG(llvm::dbgs()
215e24b91b0SMehdi Amini                        << "\"" << bestPattern->getDebugName() << "\" result "
216e24b91b0SMehdi Amini                        << succeeded(result) << "\n");
217abfd1a8bSRiver Riddle           }
2188c66344eSRiver Riddle 
2198c66344eSRiver Riddle           // Process the result of the pattern application.
2208c66344eSRiver Riddle           if (succeeded(result) && onSuccess && failed(onSuccess(*bestPattern)))
2218c66344eSRiver Riddle             result = failure();
2226176a8f9SFrederik Gossen           if (succeeded(result)) {
2236176a8f9SFrederik Gossen             LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp));
224e24b91b0SMehdi Amini             matched = true;
225e24b91b0SMehdi Amini             return;
2266176a8f9SFrederik Gossen           }
227b6eb26fdSRiver Riddle 
228abfd1a8bSRiver Riddle           // Perform any necessary cleanups.
229b6eb26fdSRiver Riddle           if (onFailure)
230abfd1a8bSRiver Riddle             onFailure(*bestPattern);
231e24b91b0SMehdi Amini         },
232e24b91b0SMehdi Amini         {op}, *bestPattern);
233e24b91b0SMehdi Amini     if (matched)
234e24b91b0SMehdi Amini       break;
23585ab413bSRiver Riddle   } while (true);
23685ab413bSRiver Riddle 
23785ab413bSRiver Riddle   if (mutableByteCodeState)
23885ab413bSRiver Riddle     mutableByteCodeState->cleanupAfterMatchAndRewrite();
23985ab413bSRiver Riddle   return result;
240b6eb26fdSRiver Riddle }
241