xref: /llvm-project/mlir/lib/IR/PatternMatch.cpp (revision 5aeb604c7ce417eea110f9803a6c5cb1cdbc5372)
1 //===- PatternMatch.cpp - Base classes for pattern match ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/PatternMatch.h"
10 #include "mlir/Config/mlir-config.h"
11 #include "mlir/IR/IRMapping.h"
12 #include "mlir/IR/Iterators.h"
13 #include "mlir/IR/RegionKindInterface.h"
14 #include "llvm/ADT/SmallPtrSet.h"
15 
16 using namespace mlir;
17 
18 //===----------------------------------------------------------------------===//
19 // PatternBenefit
20 //===----------------------------------------------------------------------===//
21 
PatternBenefit(unsigned benefit)22 PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
23   assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
24          "This pattern match benefit is too large to represent");
25 }
26 
getBenefit() const27 unsigned short PatternBenefit::getBenefit() const {
28   assert(!isImpossibleToMatch() && "Pattern doesn't match");
29   return representation;
30 }
31 
32 //===----------------------------------------------------------------------===//
33 // Pattern
34 //===----------------------------------------------------------------------===//
35 
36 //===----------------------------------------------------------------------===//
37 // OperationName Root Constructors
38 
Pattern(StringRef rootName,PatternBenefit benefit,MLIRContext * context,ArrayRef<StringRef> generatedNames)39 Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
40                  MLIRContext *context, ArrayRef<StringRef> generatedNames)
41     : Pattern(OperationName(rootName, context).getAsOpaquePointer(),
42               RootKind::OperationName, generatedNames, benefit, context) {}
43 
44 //===----------------------------------------------------------------------===//
45 // MatchAnyOpTypeTag Root Constructors
46 
Pattern(MatchAnyOpTypeTag tag,PatternBenefit benefit,MLIRContext * context,ArrayRef<StringRef> generatedNames)47 Pattern::Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit,
48                  MLIRContext *context, ArrayRef<StringRef> generatedNames)
49     : Pattern(nullptr, RootKind::Any, generatedNames, benefit, context) {}
50 
51 //===----------------------------------------------------------------------===//
52 // MatchInterfaceOpTypeTag Root Constructors
53 
Pattern(MatchInterfaceOpTypeTag tag,TypeID interfaceID,PatternBenefit benefit,MLIRContext * context,ArrayRef<StringRef> generatedNames)54 Pattern::Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID,
55                  PatternBenefit benefit, MLIRContext *context,
56                  ArrayRef<StringRef> generatedNames)
57     : Pattern(interfaceID.getAsOpaquePointer(), RootKind::InterfaceID,
58               generatedNames, benefit, context) {}
59 
60 //===----------------------------------------------------------------------===//
61 // MatchTraitOpTypeTag Root Constructors
62 
Pattern(MatchTraitOpTypeTag tag,TypeID traitID,PatternBenefit benefit,MLIRContext * context,ArrayRef<StringRef> generatedNames)63 Pattern::Pattern(MatchTraitOpTypeTag tag, TypeID traitID,
64                  PatternBenefit benefit, MLIRContext *context,
65                  ArrayRef<StringRef> generatedNames)
66     : Pattern(traitID.getAsOpaquePointer(), RootKind::TraitID, generatedNames,
67               benefit, context) {}
68 
69 //===----------------------------------------------------------------------===//
70 // General Constructors
71 
Pattern(const void * rootValue,RootKind rootKind,ArrayRef<StringRef> generatedNames,PatternBenefit benefit,MLIRContext * context)72 Pattern::Pattern(const void *rootValue, RootKind rootKind,
73                  ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
74                  MLIRContext *context)
75     : rootValue(rootValue), rootKind(rootKind), benefit(benefit),
76       contextAndHasBoundedRecursion(context, false) {
77   if (generatedNames.empty())
78     return;
79   generatedOps.reserve(generatedNames.size());
80   std::transform(generatedNames.begin(), generatedNames.end(),
81                  std::back_inserter(generatedOps), [context](StringRef name) {
82                    return OperationName(name, context);
83                  });
84 }
85 
86 //===----------------------------------------------------------------------===//
87 // RewritePattern
88 //===----------------------------------------------------------------------===//
89 
rewrite(Operation * op,PatternRewriter & rewriter) const90 void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
91   llvm_unreachable("need to implement either matchAndRewrite or one of the "
92                    "rewrite functions!");
93 }
94 
match(Operation * op) const95 LogicalResult RewritePattern::match(Operation *op) const {
96   llvm_unreachable("need to implement either match or matchAndRewrite!");
97 }
98 
99 /// Out-of-line vtable anchor.
anchor()100 void RewritePattern::anchor() {}
101 
102 //===----------------------------------------------------------------------===//
103 // RewriterBase
104 //===----------------------------------------------------------------------===//
105 
classof(const OpBuilder::Listener * base)106 bool RewriterBase::Listener::classof(const OpBuilder::Listener *base) {
107   return base->getKind() == OpBuilder::ListenerBase::Kind::RewriterBaseListener;
108 }
109 
~RewriterBase()110 RewriterBase::~RewriterBase() {
111   // Out of line to provide a vtable anchor for the class.
112 }
113 
replaceAllOpUsesWith(Operation * from,ValueRange to)114 void RewriterBase::replaceAllOpUsesWith(Operation *from, ValueRange to) {
115   // Notify the listener that we're about to replace this op.
116   if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
117     rewriteListener->notifyOperationReplaced(from, to);
118 
119   replaceAllUsesWith(from->getResults(), to);
120 }
121 
replaceAllOpUsesWith(Operation * from,Operation * to)122 void RewriterBase::replaceAllOpUsesWith(Operation *from, Operation *to) {
123   // Notify the listener that we're about to replace this op.
124   if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
125     rewriteListener->notifyOperationReplaced(from, to);
126 
127   replaceAllUsesWith(from->getResults(), to->getResults());
128 }
129 
130 /// This method replaces the results of the operation with the specified list of
131 /// values. The number of provided values must match the number of results of
132 /// the operation. The replaced op is erased.
replaceOp(Operation * op,ValueRange newValues)133 void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
134   assert(op->getNumResults() == newValues.size() &&
135          "incorrect # of replacement values");
136 
137   // Replace all result uses. Also notifies the listener of modifications.
138   replaceAllOpUsesWith(op, newValues);
139 
140   // Erase op and notify listener.
141   eraseOp(op);
142 }
143 
144 /// This method replaces the results of the operation with the specified new op
145 /// (replacement). The number of results of the two operations must match. The
146 /// replaced op is erased.
replaceOp(Operation * op,Operation * newOp)147 void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
148   assert(op && newOp && "expected non-null op");
149   assert(op->getNumResults() == newOp->getNumResults() &&
150          "ops have different number of results");
151 
152   // Replace all result uses. Also notifies the listener of modifications.
153   replaceAllOpUsesWith(op, newOp->getResults());
154 
155   // Erase op and notify listener.
156   eraseOp(op);
157 }
158 
159 /// This method erases an operation that is known to have no uses. The uses of
160 /// the given operation *must* be known to be dead.
eraseOp(Operation * op)161 void RewriterBase::eraseOp(Operation *op) {
162   assert(op->use_empty() && "expected 'op' to have no uses");
163   auto *rewriteListener = dyn_cast_if_present<Listener>(listener);
164 
165   // Fast path: If no listener is attached, the op can be dropped in one go.
166   if (!rewriteListener) {
167     op->erase();
168     return;
169   }
170 
171   // Helper function that erases a single op.
172   auto eraseSingleOp = [&](Operation *op) {
173 #ifndef NDEBUG
174     // All nested ops should have been erased already.
175     assert(
176         llvm::all_of(op->getRegions(), [&](Region &r) { return r.empty(); }) &&
177         "expected empty regions");
178     // All users should have been erased already if the op is in a region with
179     // SSA dominance.
180     if (!op->use_empty() && op->getParentOp())
181       assert(mayBeGraphRegion(*op->getParentRegion()) &&
182              "expected that op has no uses");
183 #endif // NDEBUG
184     rewriteListener->notifyOperationErased(op);
185 
186     // Explicitly drop all uses in case the op is in a graph region.
187     op->dropAllUses();
188     op->erase();
189   };
190 
191   // Nested ops must be erased one-by-one, so that listeners have a consistent
192   // view of the IR every time a notification is triggered. Users must be
193   // erased before definitions. I.e., post-order, reverse dominance.
194   std::function<void(Operation *)> eraseTree = [&](Operation *op) {
195     // Erase nested ops.
196     for (Region &r : llvm::reverse(op->getRegions())) {
197       // Erase all blocks in the right order. Successors should be erased
198       // before predecessors because successor blocks may use values defined
199       // in predecessor blocks. A post-order traversal of blocks within a
200       // region visits successors before predecessors. Repeat the traversal
201       // until the region is empty. (The block graph could be disconnected.)
202       while (!r.empty()) {
203         SmallVector<Block *> erasedBlocks;
204         // Some blocks may have invalid successor, use a set including nullptr
205         // to avoid null pointer.
206         llvm::SmallPtrSet<Block *, 4> visited{nullptr};
207         for (Block *b : llvm::post_order_ext(&r.front(), visited)) {
208           // Visit ops in reverse order.
209           for (Operation &op :
210                llvm::make_early_inc_range(ReverseIterator::makeIterable(*b)))
211             eraseTree(&op);
212           // Do not erase the block immediately. This is not supprted by the
213           // post_order iterator.
214           erasedBlocks.push_back(b);
215         }
216         for (Block *b : erasedBlocks) {
217           // Explicitly drop all uses in case there is a cycle in the block
218           // graph.
219           for (BlockArgument bbArg : b->getArguments())
220             bbArg.dropAllUses();
221           b->dropAllUses();
222           eraseBlock(b);
223         }
224       }
225     }
226     // Then erase the enclosing op.
227     eraseSingleOp(op);
228   };
229 
230   eraseTree(op);
231 }
232 
eraseBlock(Block * block)233 void RewriterBase::eraseBlock(Block *block) {
234   assert(block->use_empty() && "expected 'block' to have no uses");
235 
236   for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
237     assert(op.use_empty() && "expected 'op' to have no uses");
238     eraseOp(&op);
239   }
240 
241   // Notify the listener that the block is about to be removed.
242   if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
243     rewriteListener->notifyBlockErased(block);
244 
245   block->erase();
246 }
247 
finalizeOpModification(Operation * op)248 void RewriterBase::finalizeOpModification(Operation *op) {
249   // Notify the listener that the operation was modified.
250   if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
251     rewriteListener->notifyOperationModified(op);
252 }
253 
replaceAllUsesExcept(Value from,Value to,const SmallPtrSetImpl<Operation * > & preservedUsers)254 void RewriterBase::replaceAllUsesExcept(
255     Value from, Value to, const SmallPtrSetImpl<Operation *> &preservedUsers) {
256   return replaceUsesWithIf(from, to, [&](OpOperand &use) {
257     Operation *user = use.getOwner();
258     return !preservedUsers.contains(user);
259   });
260 }
261 
replaceUsesWithIf(Value from,Value to,function_ref<bool (OpOperand &)> functor,bool * allUsesReplaced)262 void RewriterBase::replaceUsesWithIf(Value from, Value to,
263                                      function_ref<bool(OpOperand &)> functor,
264                                      bool *allUsesReplaced) {
265   bool allReplaced = true;
266   for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
267     bool replace = functor(operand);
268     if (replace)
269       modifyOpInPlace(operand.getOwner(), [&]() { operand.set(to); });
270     allReplaced &= replace;
271   }
272   if (allUsesReplaced)
273     *allUsesReplaced = allReplaced;
274 }
275 
replaceUsesWithIf(ValueRange from,ValueRange to,function_ref<bool (OpOperand &)> functor,bool * allUsesReplaced)276 void RewriterBase::replaceUsesWithIf(ValueRange from, ValueRange to,
277                                      function_ref<bool(OpOperand &)> functor,
278                                      bool *allUsesReplaced) {
279   assert(from.size() == to.size() && "incorrect number of replacements");
280   bool allReplaced = true;
281   for (auto it : llvm::zip_equal(from, to)) {
282     bool r;
283     replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor,
284                       /*allUsesReplaced=*/&r);
285     allReplaced &= r;
286   }
287   if (allUsesReplaced)
288     *allUsesReplaced = allReplaced;
289 }
290 
inlineBlockBefore(Block * source,Block * dest,Block::iterator before,ValueRange argValues)291 void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
292                                      Block::iterator before,
293                                      ValueRange argValues) {
294   assert(argValues.size() == source->getNumArguments() &&
295          "incorrect # of argument replacement values");
296 
297   // The source block will be deleted, so it should not have any users (i.e.,
298   // there should be no predecessors).
299   assert(source->hasNoPredecessors() &&
300          "expected 'source' to have no predecessors");
301 
302   if (dest->end() != before) {
303     // The source block will be inserted in the middle of the dest block, so
304     // the source block should have no successors. Otherwise, the remainder of
305     // the dest block would be unreachable.
306     assert(source->hasNoSuccessors() &&
307            "expected 'source' to have no successors");
308   } else {
309     // The source block will be inserted at the end of the dest block, so the
310     // dest block should have no successors. Otherwise, the inserted operations
311     // will be unreachable.
312     assert(dest->hasNoSuccessors() && "expected 'dest' to have no successors");
313   }
314 
315   // Replace all of the successor arguments with the provided values.
316   for (auto it : llvm::zip(source->getArguments(), argValues))
317     replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
318 
319   // Move operations from the source block to the dest block and erase the
320   // source block.
321   if (!listener) {
322     // Fast path: If no listener is attached, move all operations at once.
323     dest->getOperations().splice(before, source->getOperations());
324   } else {
325     while (!source->empty())
326       moveOpBefore(&source->front(), dest, before);
327   }
328 
329   // Erase the source block.
330   assert(source->empty() && "expected 'source' to be empty");
331   eraseBlock(source);
332 }
333 
inlineBlockBefore(Block * source,Operation * op,ValueRange argValues)334 void RewriterBase::inlineBlockBefore(Block *source, Operation *op,
335                                      ValueRange argValues) {
336   inlineBlockBefore(source, op->getBlock(), op->getIterator(), argValues);
337 }
338 
mergeBlocks(Block * source,Block * dest,ValueRange argValues)339 void RewriterBase::mergeBlocks(Block *source, Block *dest,
340                                ValueRange argValues) {
341   inlineBlockBefore(source, dest, dest->end(), argValues);
342 }
343 
344 /// Split the operations starting at "before" (inclusive) out of the given
345 /// block into a new block, and return it.
splitBlock(Block * block,Block::iterator before)346 Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
347   // Fast path: If no listener is attached, split the block directly.
348   if (!listener)
349     return block->splitBlock(before);
350 
351   // `createBlock` sets the insertion point at the beginning of the new block.
352   InsertionGuard g(*this);
353   Block *newBlock =
354       createBlock(block->getParent(), std::next(block->getIterator()));
355 
356   // If `before` points to end of the block, no ops should be moved.
357   if (before == block->end())
358     return newBlock;
359 
360   // Move ops one-by-one from the end of `block` to the beginning of `newBlock`.
361   // Stop when the operation pointed to by `before` has been moved.
362   while (before->getBlock() != newBlock)
363     moveOpBefore(&block->back(), newBlock, newBlock->begin());
364 
365   return newBlock;
366 }
367 
368 /// Move the blocks that belong to "region" before the given position in
369 /// another region.  The two regions must be different.  The caller is in
370 /// charge to update create the operation transferring the control flow to the
371 /// region and pass it the correct block arguments.
inlineRegionBefore(Region & region,Region & parent,Region::iterator before)372 void RewriterBase::inlineRegionBefore(Region &region, Region &parent,
373                                       Region::iterator before) {
374   // Fast path: If no listener is attached, move all blocks at once.
375   if (!listener) {
376     parent.getBlocks().splice(before, region.getBlocks());
377     return;
378   }
379 
380   // Move blocks from the beginning of the region one-by-one.
381   while (!region.empty())
382     moveBlockBefore(&region.front(), &parent, before);
383 }
inlineRegionBefore(Region & region,Block * before)384 void RewriterBase::inlineRegionBefore(Region &region, Block *before) {
385   inlineRegionBefore(region, *before->getParent(), before->getIterator());
386 }
387 
moveBlockBefore(Block * block,Block * anotherBlock)388 void RewriterBase::moveBlockBefore(Block *block, Block *anotherBlock) {
389   moveBlockBefore(block, anotherBlock->getParent(),
390                   anotherBlock->getIterator());
391 }
392 
moveBlockBefore(Block * block,Region * region,Region::iterator iterator)393 void RewriterBase::moveBlockBefore(Block *block, Region *region,
394                                    Region::iterator iterator) {
395   Region *currentRegion = block->getParent();
396   Region::iterator nextIterator = std::next(block->getIterator());
397   block->moveBefore(region, iterator);
398   if (listener)
399     listener->notifyBlockInserted(block, /*previous=*/currentRegion,
400                                   /*previousIt=*/nextIterator);
401 }
402 
moveOpBefore(Operation * op,Operation * existingOp)403 void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
404   moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
405 }
406 
moveOpBefore(Operation * op,Block * block,Block::iterator iterator)407 void RewriterBase::moveOpBefore(Operation *op, Block *block,
408                                 Block::iterator iterator) {
409   Block *currentBlock = op->getBlock();
410   Block::iterator nextIterator = std::next(op->getIterator());
411   op->moveBefore(block, iterator);
412   if (listener)
413     listener->notifyOperationInserted(
414         op, /*previous=*/InsertPoint(currentBlock, nextIterator));
415 }
416 
moveOpAfter(Operation * op,Operation * existingOp)417 void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
418   moveOpAfter(op, existingOp->getBlock(), existingOp->getIterator());
419 }
420 
moveOpAfter(Operation * op,Block * block,Block::iterator iterator)421 void RewriterBase::moveOpAfter(Operation *op, Block *block,
422                                Block::iterator iterator) {
423   assert(iterator != block->end() && "cannot move after end of block");
424   moveOpBefore(op, block, std::next(iterator));
425 }
426