xref: /llvm-project/mlir/lib/Rewrite/PatternApplicator.cpp (revision 78e0cca135076154abab21eadd146dc1dfd3549f)
1 //===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements an applicator that applies pattern rewrites based upon a
10 // user defined cost model.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Rewrite/PatternApplicator.h"
15 #include "ByteCode.h"
16 #include "llvm/Support/Debug.h"
17 
18 #define DEBUG_TYPE "pattern-application"
19 
20 using namespace mlir;
21 using namespace mlir::detail;
22 
PatternApplicator(const FrozenRewritePatternSet & frozenPatternList)23 PatternApplicator::PatternApplicator(
24     const FrozenRewritePatternSet &frozenPatternList)
25     : frozenPatternList(frozenPatternList) {
26   if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
27     mutableByteCodeState = std::make_unique<PDLByteCodeMutableState>();
28     bytecode->initializeMutableState(*mutableByteCodeState);
29   }
30 }
31 PatternApplicator::~PatternApplicator() = default;
32 
33 #ifndef NDEBUG
34 /// Log a message for a pattern that is impossible to match.
logImpossibleToMatch(const Pattern & pattern)35 static void logImpossibleToMatch(const Pattern &pattern) {
36   llvm::dbgs() << "Ignoring pattern '" << pattern.getRootKind()
37                << "' because it is impossible to match or cannot lead "
38                   "to legal IR (by cost model)\n";
39 }
40 
41 /// Log IR after pattern application.
getDumpRootOp(Operation * op)42 static Operation *getDumpRootOp(Operation *op) {
43   Operation *isolatedParent =
44       op->getParentWithTrait<mlir::OpTrait::IsIsolatedFromAbove>();
45   if (isolatedParent)
46     return isolatedParent;
47   return op;
48 }
logSucessfulPatternApplication(Operation * op)49 static void logSucessfulPatternApplication(Operation *op) {
50   llvm::dbgs() << "// *** IR Dump After Pattern Application ***\n";
51   op->dump();
52   llvm::dbgs() << "\n\n";
53 }
54 #endif
55 
applyCostModel(CostModel model)56 void PatternApplicator::applyCostModel(CostModel model) {
57   // Apply the cost model to the bytecode patterns first, and then the native
58   // patterns.
59   if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
60     for (const auto &it : llvm::enumerate(bytecode->getPatterns()))
61       mutableByteCodeState->updatePatternBenefit(it.index(), model(it.value()));
62   }
63 
64   // Copy over the patterns so that we can sort by benefit based on the cost
65   // model. Patterns that are already impossible to match are ignored.
66   patterns.clear();
67   for (const auto &it : frozenPatternList.getOpSpecificNativePatterns()) {
68     for (const RewritePattern *pattern : it.second) {
69       if (pattern->getBenefit().isImpossibleToMatch())
70         LLVM_DEBUG(logImpossibleToMatch(*pattern));
71       else
72         patterns[it.first].push_back(pattern);
73     }
74   }
75   anyOpPatterns.clear();
76   for (const RewritePattern &pattern :
77        frozenPatternList.getMatchAnyOpNativePatterns()) {
78     if (pattern.getBenefit().isImpossibleToMatch())
79       LLVM_DEBUG(logImpossibleToMatch(pattern));
80     else
81       anyOpPatterns.push_back(&pattern);
82   }
83 
84   // Sort the patterns using the provided cost model.
85   llvm::SmallDenseMap<const Pattern *, PatternBenefit> benefits;
86   auto cmp = [&benefits](const Pattern *lhs, const Pattern *rhs) {
87     return benefits[lhs] > benefits[rhs];
88   };
89   auto processPatternList = [&](SmallVectorImpl<const RewritePattern *> &list) {
90     // Special case for one pattern in the list, which is the most common case.
91     if (list.size() == 1) {
92       if (model(*list.front()).isImpossibleToMatch()) {
93         LLVM_DEBUG(logImpossibleToMatch(*list.front()));
94         list.clear();
95       }
96       return;
97     }
98 
99     // Collect the dynamic benefits for the current pattern list.
100     benefits.clear();
101     for (const Pattern *pat : list)
102       benefits.try_emplace(pat, model(*pat));
103 
104     // Sort patterns with highest benefit first, and remove those that are
105     // impossible to match.
106     std::stable_sort(list.begin(), list.end(), cmp);
107     while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
108       LLVM_DEBUG(logImpossibleToMatch(*list.back()));
109       list.pop_back();
110     }
111   };
112   for (auto &it : patterns)
113     processPatternList(it.second);
114   processPatternList(anyOpPatterns);
115 }
116 
walkAllPatterns(function_ref<void (const Pattern &)> walk)117 void PatternApplicator::walkAllPatterns(
118     function_ref<void(const Pattern &)> walk) {
119   for (const auto &it : frozenPatternList.getOpSpecificNativePatterns())
120     for (const auto &pattern : it.second)
121       walk(*pattern);
122   for (const Pattern &it : frozenPatternList.getMatchAnyOpNativePatterns())
123     walk(it);
124   if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
125     for (const Pattern &it : bytecode->getPatterns())
126       walk(it);
127   }
128 }
129 
matchAndRewrite(Operation * op,PatternRewriter & rewriter,function_ref<bool (const Pattern &)> canApply,function_ref<void (const Pattern &)> onFailure,function_ref<LogicalResult (const Pattern &)> onSuccess)130 LogicalResult PatternApplicator::matchAndRewrite(
131     Operation *op, PatternRewriter &rewriter,
132     function_ref<bool(const Pattern &)> canApply,
133     function_ref<void(const Pattern &)> onFailure,
134     function_ref<LogicalResult(const Pattern &)> onSuccess) {
135   // Before checking native patterns, first match against the bytecode. This
136   // won't automatically perform any rewrites so there is no need to worry about
137   // conflicts.
138   SmallVector<PDLByteCode::MatchResult, 4> pdlMatches;
139   const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode();
140   if (bytecode)
141     bytecode->match(op, rewriter, pdlMatches, *mutableByteCodeState);
142 
143   // Check to see if there are patterns matching this specific operation type.
144   MutableArrayRef<const RewritePattern *> opPatterns;
145   auto patternIt = patterns.find(op->getName());
146   if (patternIt != patterns.end())
147     opPatterns = patternIt->second;
148 
149   // Process the patterns for that match the specific operation type, and any
150   // operation type in an interleaved fashion.
151   unsigned opIt = 0, opE = opPatterns.size();
152   unsigned anyIt = 0, anyE = anyOpPatterns.size();
153   unsigned pdlIt = 0, pdlE = pdlMatches.size();
154   LogicalResult result = failure();
155   do {
156     // Find the next pattern with the highest benefit.
157     const Pattern *bestPattern = nullptr;
158     unsigned *bestPatternIt = &opIt;
159 
160     /// Operation specific patterns.
161     if (opIt < opE)
162       bestPattern = opPatterns[opIt];
163     /// Operation agnostic patterns.
164     if (anyIt < anyE &&
165         (!bestPattern ||
166          bestPattern->getBenefit() < anyOpPatterns[anyIt]->getBenefit())) {
167       bestPatternIt = &anyIt;
168       bestPattern = anyOpPatterns[anyIt];
169     }
170 
171     const PDLByteCode::MatchResult *pdlMatch = nullptr;
172     /// PDL patterns.
173     if (pdlIt < pdlE && (!bestPattern || bestPattern->getBenefit() <
174                                              pdlMatches[pdlIt].benefit)) {
175       bestPatternIt = &pdlIt;
176       pdlMatch = &pdlMatches[pdlIt];
177       bestPattern = pdlMatch->pattern;
178     }
179 
180     if (!bestPattern)
181       break;
182 
183     // Update the pattern iterator on failure so that this pattern isn't
184     // attempted again.
185     ++(*bestPatternIt);
186 
187     // Check that the pattern can be applied.
188     if (canApply && !canApply(*bestPattern))
189       continue;
190 
191     // Try to match and rewrite this pattern. The patterns are sorted by
192     // benefit, so if we match we can immediately rewrite. For PDL patterns, the
193     // match has already been performed, we just need to rewrite.
194     bool matched = false;
195     op->getContext()->executeAction<ApplyPatternAction>(
196         [&]() {
197           rewriter.setInsertionPoint(op);
198 #ifndef NDEBUG
199           // Operation `op` may be invalidated after applying the rewrite
200           // pattern.
201           Operation *dumpRootOp = getDumpRootOp(op);
202 #endif
203           if (pdlMatch) {
204             result =
205                 bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
206           } else {
207             LLVM_DEBUG(llvm::dbgs() << "Trying to match \""
208                                     << bestPattern->getDebugName() << "\"\n");
209 
210             const auto *pattern =
211                 static_cast<const RewritePattern *>(bestPattern);
212             result = pattern->matchAndRewrite(op, rewriter);
213 
214             LLVM_DEBUG(llvm::dbgs()
215                        << "\"" << bestPattern->getDebugName() << "\" result "
216                        << succeeded(result) << "\n");
217           }
218 
219           // Process the result of the pattern application.
220           if (succeeded(result) && onSuccess && failed(onSuccess(*bestPattern)))
221             result = failure();
222           if (succeeded(result)) {
223             LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp));
224             matched = true;
225             return;
226           }
227 
228           // Perform any necessary cleanups.
229           if (onFailure)
230             onFailure(*bestPattern);
231         },
232         {op}, *bestPattern);
233     if (matched)
234       break;
235   } while (true);
236 
237   if (mutableByteCodeState)
238     mutableByteCodeState->cleanupAfterMatchAndRewrite();
239   return result;
240 }
241