xref: /llvm-project/mlir/lib/IR/PatternMatch.cpp (revision 5aeb604c7ce417eea110f9803a6c5cb1cdbc5372)
//===- PatternMatch.cpp - Base classes for pattern match ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
87de0da95SChris Lattner 
97de0da95SChris Lattner #include "mlir/IR/PatternMatch.h"
106ae7f66fSJacques Pienaar #include "mlir/Config/mlir-config.h"
114d67b278SJeff Niu #include "mlir/IR/IRMapping.h"
12695a5a6aSMatthias Springer #include "mlir/IR/Iterators.h"
13695a5a6aSMatthias Springer #include "mlir/IR/RegionKindInterface.h"
14*5aeb604cSMaheshRavishankar #include "llvm/ADT/SmallPtrSet.h"
156fb3d597SDiego Caballero 
167de0da95SChris Lattner using namespace mlir;
177de0da95SChris Lattner 
//===----------------------------------------------------------------------===//
// PatternBenefit
//===----------------------------------------------------------------------===//
21b99bd771SRiver Riddle 
PatternBenefit(unsigned benefit)227de0da95SChris Lattner PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
237de0da95SChris Lattner   assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
247de0da95SChris Lattner          "This pattern match benefit is too large to represent");
257de0da95SChris Lattner }
267de0da95SChris Lattner 
getBenefit() const277de0da95SChris Lattner unsigned short PatternBenefit::getBenefit() const {
283e98fbf4SRiver Riddle   assert(!isImpossibleToMatch() && "Pattern doesn't match");
297de0da95SChris Lattner   return representation;
307de0da95SChris Lattner }
317de0da95SChris Lattner 
//===----------------------------------------------------------------------===//
// Pattern
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// OperationName Root Constructors
3876f3c2f3SRiver Riddle 
Pattern(StringRef rootName,PatternBenefit benefit,MLIRContext * context,ArrayRef<StringRef> generatedNames)3986a5323fSChris Lattner Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
4076f3c2f3SRiver Riddle                  MLIRContext *context, ArrayRef<StringRef> generatedNames)
4176f3c2f3SRiver Riddle     : Pattern(OperationName(rootName, context).getAsOpaquePointer(),
4276f3c2f3SRiver Riddle               RootKind::OperationName, generatedNames, benefit, context) {}
4376f3c2f3SRiver Riddle 
//===----------------------------------------------------------------------===//
// MatchAnyOpTypeTag Root Constructors
4676f3c2f3SRiver Riddle 
Pattern(MatchAnyOpTypeTag tag,PatternBenefit benefit,MLIRContext * context,ArrayRef<StringRef> generatedNames)4776f3c2f3SRiver Riddle Pattern::Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit,
4876f3c2f3SRiver Riddle                  MLIRContext *context, ArrayRef<StringRef> generatedNames)
4976f3c2f3SRiver Riddle     : Pattern(nullptr, RootKind::Any, generatedNames, benefit, context) {}
5076f3c2f3SRiver Riddle 
//===----------------------------------------------------------------------===//
// MatchInterfaceOpTypeTag Root Constructors
5376f3c2f3SRiver Riddle 
Pattern(MatchInterfaceOpTypeTag tag,TypeID interfaceID,PatternBenefit benefit,MLIRContext * context,ArrayRef<StringRef> generatedNames)5476f3c2f3SRiver Riddle Pattern::Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID,
5576f3c2f3SRiver Riddle                  PatternBenefit benefit, MLIRContext *context,
5676f3c2f3SRiver Riddle                  ArrayRef<StringRef> generatedNames)
5776f3c2f3SRiver Riddle     : Pattern(interfaceID.getAsOpaquePointer(), RootKind::InterfaceID,
5876f3c2f3SRiver Riddle               generatedNames, benefit, context) {}
5976f3c2f3SRiver Riddle 
//===----------------------------------------------------------------------===//
// MatchTraitOpTypeTag Root Constructors
6276f3c2f3SRiver Riddle 
Pattern(MatchTraitOpTypeTag tag,TypeID traitID,PatternBenefit benefit,MLIRContext * context,ArrayRef<StringRef> generatedNames)6376f3c2f3SRiver Riddle Pattern::Pattern(MatchTraitOpTypeTag tag, TypeID traitID,
6476f3c2f3SRiver Riddle                  PatternBenefit benefit, MLIRContext *context,
6576f3c2f3SRiver Riddle                  ArrayRef<StringRef> generatedNames)
6676f3c2f3SRiver Riddle     : Pattern(traitID.getAsOpaquePointer(), RootKind::TraitID, generatedNames,
6776f3c2f3SRiver Riddle               benefit, context) {}
6876f3c2f3SRiver Riddle 
//===----------------------------------------------------------------------===//
// General Constructors
7176f3c2f3SRiver Riddle 
Pattern(const void * rootValue,RootKind rootKind,ArrayRef<StringRef> generatedNames,PatternBenefit benefit,MLIRContext * context)7276f3c2f3SRiver Riddle Pattern::Pattern(const void *rootValue, RootKind rootKind,
7376f3c2f3SRiver Riddle                  ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
7486a5323fSChris Lattner                  MLIRContext *context)
7576f3c2f3SRiver Riddle     : rootValue(rootValue), rootKind(rootKind), benefit(benefit),
7676f3c2f3SRiver Riddle       contextAndHasBoundedRecursion(context, false) {
7776f3c2f3SRiver Riddle   if (generatedNames.empty())
7876f3c2f3SRiver Riddle     return;
79b99bd771SRiver Riddle   generatedOps.reserve(generatedNames.size());
80b99bd771SRiver Riddle   std::transform(generatedNames.begin(), generatedNames.end(),
81b99bd771SRiver Riddle                  std::back_inserter(generatedOps), [context](StringRef name) {
82b99bd771SRiver Riddle                    return OperationName(name, context);
83b99bd771SRiver Riddle                  });
84b99bd771SRiver Riddle }
853f2530cdSChris Lattner 
//===----------------------------------------------------------------------===//
// RewritePattern
//===----------------------------------------------------------------------===//
893f2530cdSChris Lattner 
rewrite(Operation * op,PatternRewriter & rewriter) const90f9d91531SRiver Riddle void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
915de726f4SRiver Riddle   llvm_unreachable("need to implement either matchAndRewrite or one of the "
925de726f4SRiver Riddle                    "rewrite functions!");
935de726f4SRiver Riddle }
945de726f4SRiver Riddle 
match(Operation * op) const953145427dSRiver Riddle LogicalResult RewritePattern::match(Operation *op) const {
965de726f4SRiver Riddle   llvm_unreachable("need to implement either match or matchAndRewrite!");
977de0da95SChris Lattner }
987de0da95SChris Lattner 
99b99bd771SRiver Riddle /// Out-of-line vtable anchor.
anchor()100b99bd771SRiver Riddle void RewritePattern::anchor() {}
101b99bd771SRiver Riddle 
//===----------------------------------------------------------------------===//
// RewriterBase
//===----------------------------------------------------------------------===//
105647f8cabSRiver Riddle 
classof(const OpBuilder::Listener * base)106c6532830SMatthias Springer bool RewriterBase::Listener::classof(const OpBuilder::Listener *base) {
107c6532830SMatthias Springer   return base->getKind() == OpBuilder::ListenerBase::Kind::RewriterBaseListener;
108c6532830SMatthias Springer }
109c6532830SMatthias Springer 
~RewriterBase()110ec10f066SRiver Riddle RewriterBase::~RewriterBase() {
1117de0da95SChris Lattner   // Out of line to provide a vtable anchor for the class.
1127de0da95SChris Lattner }
1137de0da95SChris Lattner 
replaceAllOpUsesWith(Operation * from,ValueRange to)11438113a08SMatthias Springer void RewriterBase::replaceAllOpUsesWith(Operation *from, ValueRange to) {
11538113a08SMatthias Springer   // Notify the listener that we're about to replace this op.
11638113a08SMatthias Springer   if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
11738113a08SMatthias Springer     rewriteListener->notifyOperationReplaced(from, to);
11838113a08SMatthias Springer 
11938113a08SMatthias Springer   replaceAllUsesWith(from->getResults(), to);
12038113a08SMatthias Springer }
12138113a08SMatthias Springer 
replaceAllOpUsesWith(Operation * from,Operation * to)12238113a08SMatthias Springer void RewriterBase::replaceAllOpUsesWith(Operation *from, Operation *to) {
12338113a08SMatthias Springer   // Notify the listener that we're about to replace this op.
12438113a08SMatthias Springer   if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
12538113a08SMatthias Springer     rewriteListener->notifyOperationReplaced(from, to);
12638113a08SMatthias Springer 
12738113a08SMatthias Springer   replaceAllUsesWith(from->getResults(), to->getResults());
12838113a08SMatthias Springer }
12938113a08SMatthias Springer 
130ec10f066SRiver Riddle /// This method replaces the results of the operation with the specified list of
131ec10f066SRiver Riddle /// values. The number of provided values must match the number of results of
13271d50c89SMatthias Springer /// the operation. The replaced op is erased.
replaceOp(Operation * op,ValueRange newValues)133ec10f066SRiver Riddle void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
13421f4b84cSMatthias Springer   assert(op->getNumResults() == newValues.size() &&
13521f4b84cSMatthias Springer          "incorrect # of replacement values");
13621f4b84cSMatthias Springer 
13759a92019SMatthias Springer   // Replace all result uses. Also notifies the listener of modifications.
138f1aa7837SMatthias Springer   replaceAllOpUsesWith(op, newValues);
1391427d0f0SChris Lattner 
140695a5a6aSMatthias Springer   // Erase op and notify listener.
14171d50c89SMatthias Springer   eraseOp(op);
14271d50c89SMatthias Springer }
14371d50c89SMatthias Springer 
14471d50c89SMatthias Springer /// This method replaces the results of the operation with the specified new op
14571d50c89SMatthias Springer /// (replacement). The number of results of the two operations must match. The
14671d50c89SMatthias Springer /// replaced op is erased.
replaceOp(Operation * op,Operation * newOp)14771d50c89SMatthias Springer void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
14871d50c89SMatthias Springer   assert(op && newOp && "expected non-null op");
14971d50c89SMatthias Springer   assert(op->getNumResults() == newOp->getNumResults() &&
15071d50c89SMatthias Springer          "ops have different number of results");
15171d50c89SMatthias Springer 
15259a92019SMatthias Springer   // Replace all result uses. Also notifies the listener of modifications.
153f1aa7837SMatthias Springer   replaceAllOpUsesWith(op, newOp->getResults());
15471d50c89SMatthias Springer 
155695a5a6aSMatthias Springer   // Erase op and notify listener.
15671d50c89SMatthias Springer   eraseOp(op);
1571427d0f0SChris Lattner }
1581427d0f0SChris Lattner 
159dfe09cc6SRiver Riddle /// This method erases an operation that is known to have no uses. The uses of
160dfe09cc6SRiver Riddle /// the given operation *must* be known to be dead.
eraseOp(Operation * op)161ec10f066SRiver Riddle void RewriterBase::eraseOp(Operation *op) {
162dfe09cc6SRiver Riddle   assert(op->use_empty() && "expected 'op' to have no uses");
163695a5a6aSMatthias Springer   auto *rewriteListener = dyn_cast_if_present<Listener>(listener);
164695a5a6aSMatthias Springer 
165695a5a6aSMatthias Springer   // Fast path: If no listener is attached, the op can be dropped in one go.
166695a5a6aSMatthias Springer   if (!rewriteListener) {
167dfe09cc6SRiver Riddle     op->erase();
168695a5a6aSMatthias Springer     return;
169695a5a6aSMatthias Springer   }
170695a5a6aSMatthias Springer 
171695a5a6aSMatthias Springer   // Helper function that erases a single op.
172695a5a6aSMatthias Springer   auto eraseSingleOp = [&](Operation *op) {
173695a5a6aSMatthias Springer #ifndef NDEBUG
174695a5a6aSMatthias Springer     // All nested ops should have been erased already.
175695a5a6aSMatthias Springer     assert(
176695a5a6aSMatthias Springer         llvm::all_of(op->getRegions(), [&](Region &r) { return r.empty(); }) &&
177695a5a6aSMatthias Springer         "expected empty regions");
178695a5a6aSMatthias Springer     // All users should have been erased already if the op is in a region with
179695a5a6aSMatthias Springer     // SSA dominance.
180695a5a6aSMatthias Springer     if (!op->use_empty() && op->getParentOp())
181695a5a6aSMatthias Springer       assert(mayBeGraphRegion(*op->getParentRegion()) &&
182695a5a6aSMatthias Springer              "expected that op has no uses");
183695a5a6aSMatthias Springer #endif // NDEBUG
184914e6074SMatthias Springer     rewriteListener->notifyOperationErased(op);
185695a5a6aSMatthias Springer 
186695a5a6aSMatthias Springer     // Explicitly drop all uses in case the op is in a graph region.
187695a5a6aSMatthias Springer     op->dropAllUses();
188695a5a6aSMatthias Springer     op->erase();
189695a5a6aSMatthias Springer   };
190695a5a6aSMatthias Springer 
191695a5a6aSMatthias Springer   // Nested ops must be erased one-by-one, so that listeners have a consistent
192695a5a6aSMatthias Springer   // view of the IR every time a notification is triggered. Users must be
193695a5a6aSMatthias Springer   // erased before definitions. I.e., post-order, reverse dominance.
194695a5a6aSMatthias Springer   std::function<void(Operation *)> eraseTree = [&](Operation *op) {
195695a5a6aSMatthias Springer     // Erase nested ops.
196695a5a6aSMatthias Springer     for (Region &r : llvm::reverse(op->getRegions())) {
197695a5a6aSMatthias Springer       // Erase all blocks in the right order. Successors should be erased
198695a5a6aSMatthias Springer       // before predecessors because successor blocks may use values defined
199695a5a6aSMatthias Springer       // in predecessor blocks. A post-order traversal of blocks within a
200695a5a6aSMatthias Springer       // region visits successors before predecessors. Repeat the traversal
201695a5a6aSMatthias Springer       // until the region is empty. (The block graph could be disconnected.)
202695a5a6aSMatthias Springer       while (!r.empty()) {
203695a5a6aSMatthias Springer         SmallVector<Block *> erasedBlocks;
20446f65e45SCongcong Cai         // Some blocks may have invalid successor, use a set including nullptr
20546f65e45SCongcong Cai         // to avoid null pointer.
20646f65e45SCongcong Cai         llvm::SmallPtrSet<Block *, 4> visited{nullptr};
20746f65e45SCongcong Cai         for (Block *b : llvm::post_order_ext(&r.front(), visited)) {
208695a5a6aSMatthias Springer           // Visit ops in reverse order.
209695a5a6aSMatthias Springer           for (Operation &op :
210695a5a6aSMatthias Springer                llvm::make_early_inc_range(ReverseIterator::makeIterable(*b)))
211695a5a6aSMatthias Springer             eraseTree(&op);
212695a5a6aSMatthias Springer           // Do not erase the block immediately. This is not supprted by the
213695a5a6aSMatthias Springer           // post_order iterator.
214695a5a6aSMatthias Springer           erasedBlocks.push_back(b);
215695a5a6aSMatthias Springer         }
216695a5a6aSMatthias Springer         for (Block *b : erasedBlocks) {
217695a5a6aSMatthias Springer           // Explicitly drop all uses in case there is a cycle in the block
218695a5a6aSMatthias Springer           // graph.
219695a5a6aSMatthias Springer           for (BlockArgument bbArg : b->getArguments())
220695a5a6aSMatthias Springer             bbArg.dropAllUses();
221695a5a6aSMatthias Springer           b->dropAllUses();
22262bf7710SMatthias Springer           eraseBlock(b);
223695a5a6aSMatthias Springer         }
224695a5a6aSMatthias Springer       }
225695a5a6aSMatthias Springer     }
226695a5a6aSMatthias Springer     // Then erase the enclosing op.
227695a5a6aSMatthias Springer     eraseSingleOp(op);
228695a5a6aSMatthias Springer   };
229695a5a6aSMatthias Springer 
230695a5a6aSMatthias Springer   eraseTree(op);
231dfe09cc6SRiver Riddle }
232dfe09cc6SRiver Riddle 
eraseBlock(Block * block)233ec10f066SRiver Riddle void RewriterBase::eraseBlock(Block *block) {
23462bf7710SMatthias Springer   assert(block->use_empty() && "expected 'block' to have no uses");
23562bf7710SMatthias Springer 
2363f9cdd44SUday Bondhugula   for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
2373f9cdd44SUday Bondhugula     assert(op.use_empty() && "expected 'op' to have no uses");
2383f9cdd44SUday Bondhugula     eraseOp(&op);
2393f9cdd44SUday Bondhugula   }
24062bf7710SMatthias Springer 
24162bf7710SMatthias Springer   // Notify the listener that the block is about to be removed.
24262bf7710SMatthias Springer   if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
243914e6074SMatthias Springer     rewriteListener->notifyBlockErased(block);
24462bf7710SMatthias Springer 
2453f9cdd44SUday Bondhugula   block->erase();
2463f9cdd44SUday Bondhugula }
2473f9cdd44SUday Bondhugula 
finalizeOpModification(Operation * op)2485fcf907bSMatthias Springer void RewriterBase::finalizeOpModification(Operation *op) {
249bafc4dfcSMatthias Springer   // Notify the listener that the operation was modified.
250bafc4dfcSMatthias Springer   if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
251bafc4dfcSMatthias Springer     rewriteListener->notifyOperationModified(op);
252bafc4dfcSMatthias Springer }
253bafc4dfcSMatthias Springer 
replaceAllUsesExcept(Value from,Value to,const SmallPtrSetImpl<Operation * > & preservedUsers)254*5aeb604cSMaheshRavishankar void RewriterBase::replaceAllUsesExcept(
255*5aeb604cSMaheshRavishankar     Value from, Value to, const SmallPtrSetImpl<Operation *> &preservedUsers) {
256*5aeb604cSMaheshRavishankar   return replaceUsesWithIf(from, to, [&](OpOperand &use) {
257*5aeb604cSMaheshRavishankar     Operation *user = use.getOwner();
258*5aeb604cSMaheshRavishankar     return !preservedUsers.contains(user);
259*5aeb604cSMaheshRavishankar   });
260*5aeb604cSMaheshRavishankar }
261*5aeb604cSMaheshRavishankar 
replaceUsesWithIf(Value from,Value to,function_ref<bool (OpOperand &)> functor,bool * allUsesReplaced)26221f4b84cSMatthias Springer void RewriterBase::replaceUsesWithIf(Value from, Value to,
26359a92019SMatthias Springer                                      function_ref<bool(OpOperand &)> functor,
26459a92019SMatthias Springer                                      bool *allUsesReplaced) {
26559a92019SMatthias Springer   bool allReplaced = true;
26677603e28SDiego Caballero   for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
26759a92019SMatthias Springer     bool replace = functor(operand);
26859a92019SMatthias Springer     if (replace)
2695fcf907bSMatthias Springer       modifyOpInPlace(operand.getOwner(), [&]() { operand.set(to); });
27059a92019SMatthias Springer     allReplaced &= replace;
27177603e28SDiego Caballero   }
27259a92019SMatthias Springer   if (allUsesReplaced)
27359a92019SMatthias Springer     *allUsesReplaced = allReplaced;
27459a92019SMatthias Springer }
27559a92019SMatthias Springer 
replaceUsesWithIf(ValueRange from,ValueRange to,function_ref<bool (OpOperand &)> functor,bool * allUsesReplaced)27659a92019SMatthias Springer void RewriterBase::replaceUsesWithIf(ValueRange from, ValueRange to,
27759a92019SMatthias Springer                                      function_ref<bool(OpOperand &)> functor,
27859a92019SMatthias Springer                                      bool *allUsesReplaced) {
27959a92019SMatthias Springer   assert(from.size() == to.size() && "incorrect number of replacements");
28059a92019SMatthias Springer   bool allReplaced = true;
28159a92019SMatthias Springer   for (auto it : llvm::zip_equal(from, to)) {
28259a92019SMatthias Springer     bool r;
28359a92019SMatthias Springer     replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor,
28459a92019SMatthias Springer                       /*allUsesReplaced=*/&r);
28559a92019SMatthias Springer     allReplaced &= r;
28659a92019SMatthias Springer   }
28759a92019SMatthias Springer   if (allUsesReplaced)
28859a92019SMatthias Springer     *allUsesReplaced = allReplaced;
28977603e28SDiego Caballero }
29077603e28SDiego Caballero 
inlineBlockBefore(Block * source,Block * dest,Block::iterator before,ValueRange argValues)29142c31d83SMatthias Springer void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
29242c31d83SMatthias Springer                                      Block::iterator before,
2939c7b0c4aSRahul Joshi                                      ValueRange argValues) {
29442c31d83SMatthias Springer   assert(argValues.size() == source->getNumArguments() &&
29542c31d83SMatthias Springer          "incorrect # of argument replacement values");
29642c31d83SMatthias Springer 
29742c31d83SMatthias Springer   // The source block will be deleted, so it should not have any users (i.e.,
29842c31d83SMatthias Springer   // there should be no predecessors).
2999c7b0c4aSRahul Joshi   assert(source->hasNoPredecessors() &&
3009c7b0c4aSRahul Joshi          "expected 'source' to have no predecessors");
30142c31d83SMatthias Springer 
30242c31d83SMatthias Springer   if (dest->end() != before) {
30342c31d83SMatthias Springer     // The source block will be inserted in the middle of the dest block, so
30442c31d83SMatthias Springer     // the source block should have no successors. Otherwise, the remainder of
30542c31d83SMatthias Springer     // the dest block would be unreachable.
3069c7b0c4aSRahul Joshi     assert(source->hasNoSuccessors() &&
3079c7b0c4aSRahul Joshi            "expected 'source' to have no successors");
30842c31d83SMatthias Springer   } else {
30942c31d83SMatthias Springer     // The source block will be inserted at the end of the dest block, so the
31042c31d83SMatthias Springer     // dest block should have no successors. Otherwise, the inserted operations
31142c31d83SMatthias Springer     // will be unreachable.
31242c31d83SMatthias Springer     assert(dest->hasNoSuccessors() && "expected 'dest' to have no successors");
31342c31d83SMatthias Springer   }
3149c7b0c4aSRahul Joshi 
31542c31d83SMatthias Springer   // Replace all of the successor arguments with the provided values.
31642c31d83SMatthias Springer   for (auto it : llvm::zip(source->getArguments(), argValues))
31742c31d83SMatthias Springer     replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
3189c7b0c4aSRahul Joshi 
31942c31d83SMatthias Springer   // Move operations from the source block to the dest block and erase the
32042c31d83SMatthias Springer   // source block.
321c672b342SMatthias Springer   if (!listener) {
322c672b342SMatthias Springer     // Fast path: If no listener is attached, move all operations at once.
32342c31d83SMatthias Springer     dest->getOperations().splice(before, source->getOperations());
324c672b342SMatthias Springer   } else {
325c672b342SMatthias Springer     while (!source->empty())
326c672b342SMatthias Springer       moveOpBefore(&source->front(), dest, before);
327c672b342SMatthias Springer   }
328c672b342SMatthias Springer 
329c672b342SMatthias Springer   // Erase the source block.
330c672b342SMatthias Springer   assert(source->empty() && "expected 'source' to be empty");
33162bf7710SMatthias Springer   eraseBlock(source);
33242c31d83SMatthias Springer }
3339c7b0c4aSRahul Joshi 
inlineBlockBefore(Block * source,Operation * op,ValueRange argValues)33442c31d83SMatthias Springer void RewriterBase::inlineBlockBefore(Block *source, Operation *op,
33542c31d83SMatthias Springer                                      ValueRange argValues) {
33642c31d83SMatthias Springer   inlineBlockBefore(source, op->getBlock(), op->getIterator(), argValues);
33742c31d83SMatthias Springer }
33842c31d83SMatthias Springer 
mergeBlocks(Block * source,Block * dest,ValueRange argValues)33942c31d83SMatthias Springer void RewriterBase::mergeBlocks(Block *source, Block *dest,
34042c31d83SMatthias Springer                                ValueRange argValues) {
34142c31d83SMatthias Springer   inlineBlockBefore(source, dest, dest->end(), argValues);
3429c7b0c4aSRahul Joshi }
3439c7b0c4aSRahul Joshi 
3442366561aSRiver Riddle /// Split the operations starting at "before" (inclusive) out of the given
3452366561aSRiver Riddle /// block into a new block, and return it.
splitBlock(Block * block,Block::iterator before)346ec10f066SRiver Riddle Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
347c2675ba9SMatthias Springer   // Fast path: If no listener is attached, split the block directly.
348c2675ba9SMatthias Springer   if (!listener)
3492366561aSRiver Riddle     return block->splitBlock(before);
350c2675ba9SMatthias Springer 
351c2675ba9SMatthias Springer   // `createBlock` sets the insertion point at the beginning of the new block.
352c2675ba9SMatthias Springer   InsertionGuard g(*this);
353c2675ba9SMatthias Springer   Block *newBlock =
354c2675ba9SMatthias Springer       createBlock(block->getParent(), std::next(block->getIterator()));
355c2675ba9SMatthias Springer 
356c2675ba9SMatthias Springer   // If `before` points to end of the block, no ops should be moved.
357c2675ba9SMatthias Springer   if (before == block->end())
358c2675ba9SMatthias Springer     return newBlock;
359c2675ba9SMatthias Springer 
360c2675ba9SMatthias Springer   // Move ops one-by-one from the end of `block` to the beginning of `newBlock`.
361c2675ba9SMatthias Springer   // Stop when the operation pointed to by `before` has been moved.
362c2675ba9SMatthias Springer   while (before->getBlock() != newBlock)
363c2675ba9SMatthias Springer     moveOpBefore(&block->back(), newBlock, newBlock->begin());
364c2675ba9SMatthias Springer 
365c2675ba9SMatthias Springer   return newBlock;
3662366561aSRiver Riddle }
3672366561aSRiver Riddle 
3688ad35b90SAlex Zinenko /// Move the blocks that belong to "region" before the given position in
3698ad35b90SAlex Zinenko /// another region.  The two regions must be different.  The caller is in
3708ad35b90SAlex Zinenko /// charge to update create the operation transferring the control flow to the
3718ad35b90SAlex Zinenko /// region and pass it the correct block arguments.
inlineRegionBefore(Region & region,Region & parent,Region::iterator before)372ec10f066SRiver Riddle void RewriterBase::inlineRegionBefore(Region &region, Region &parent,
3738ad35b90SAlex Zinenko                                       Region::iterator before) {
3743ed98cb3SMatthias Springer   // Fast path: If no listener is attached, move all blocks at once.
3753ed98cb3SMatthias Springer   if (!listener) {
3763e99d995SRiver Riddle     parent.getBlocks().splice(before, region.getBlocks());
3773ed98cb3SMatthias Springer     return;
3783ed98cb3SMatthias Springer   }
3793ed98cb3SMatthias Springer 
3803ed98cb3SMatthias Springer   // Move blocks from the beginning of the region one-by-one.
381da784a25SMatthias Springer   while (!region.empty())
382da784a25SMatthias Springer     moveBlockBefore(&region.front(), &parent, before);
3833e99d995SRiver Riddle }
inlineRegionBefore(Region & region,Block * before)384ec10f066SRiver Riddle void RewriterBase::inlineRegionBefore(Region &region, Block *before) {
3853e99d995SRiver Riddle   inlineRegionBefore(region, *before->getParent(), before->getIterator());
3868ad35b90SAlex Zinenko }
3878ad35b90SAlex Zinenko 
moveBlockBefore(Block * block,Block * anotherBlock)388da784a25SMatthias Springer void RewriterBase::moveBlockBefore(Block *block, Block *anotherBlock) {
389da784a25SMatthias Springer   moveBlockBefore(block, anotherBlock->getParent(),
390da784a25SMatthias Springer                   anotherBlock->getIterator());
391da784a25SMatthias Springer }
392da784a25SMatthias Springer 
moveBlockBefore(Block * block,Region * region,Region::iterator iterator)393da784a25SMatthias Springer void RewriterBase::moveBlockBefore(Block *block, Region *region,
394da784a25SMatthias Springer                                    Region::iterator iterator) {
395da784a25SMatthias Springer   Region *currentRegion = block->getParent();
396da784a25SMatthias Springer   Region::iterator nextIterator = std::next(block->getIterator());
397da784a25SMatthias Springer   block->moveBefore(region, iterator);
398da784a25SMatthias Springer   if (listener)
399da784a25SMatthias Springer     listener->notifyBlockInserted(block, /*previous=*/currentRegion,
400da784a25SMatthias Springer                                   /*previousIt=*/nextIterator);
401da784a25SMatthias Springer }
402da784a25SMatthias Springer 
moveOpBefore(Operation * op,Operation * existingOp)4035cc0f76dSMatthias Springer void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
4045cc0f76dSMatthias Springer   moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
4055cc0f76dSMatthias Springer }
4065cc0f76dSMatthias Springer 
moveOpBefore(Operation * op,Block * block,Block::iterator iterator)4075cc0f76dSMatthias Springer void RewriterBase::moveOpBefore(Operation *op, Block *block,
4085cc0f76dSMatthias Springer                                 Block::iterator iterator) {
4095cc0f76dSMatthias Springer   Block *currentBlock = op->getBlock();
410da784a25SMatthias Springer   Block::iterator nextIterator = std::next(op->getIterator());
4115cc0f76dSMatthias Springer   op->moveBefore(block, iterator);
4125cc0f76dSMatthias Springer   if (listener)
4135cc0f76dSMatthias Springer     listener->notifyOperationInserted(
414da784a25SMatthias Springer         op, /*previous=*/InsertPoint(currentBlock, nextIterator));
4155cc0f76dSMatthias Springer }
4165cc0f76dSMatthias Springer 
moveOpAfter(Operation * op,Operation * existingOp)4175cc0f76dSMatthias Springer void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
4185cc0f76dSMatthias Springer   moveOpAfter(op, existingOp->getBlock(), existingOp->getIterator());
4195cc0f76dSMatthias Springer }
4205cc0f76dSMatthias Springer 
moveOpAfter(Operation * op,Block * block,Block::iterator iterator)4215cc0f76dSMatthias Springer void RewriterBase::moveOpAfter(Operation *op, Block *block,
4225cc0f76dSMatthias Springer                                Block::iterator iterator) {
423da784a25SMatthias Springer   assert(iterator != block->end() && "cannot move after end of block");
424da784a25SMatthias Springer   moveOpBefore(op, block, std::next(iterator));
4255cc0f76dSMatthias Springer }
426