17de0da95SChris Lattner //===- PatternMatch.cpp - Base classes for pattern match ------------------===//
27de0da95SChris Lattner //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
67de0da95SChris Lattner //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
87de0da95SChris Lattner
97de0da95SChris Lattner #include "mlir/IR/PatternMatch.h"
106ae7f66fSJacques Pienaar #include "mlir/Config/mlir-config.h"
114d67b278SJeff Niu #include "mlir/IR/IRMapping.h"
12695a5a6aSMatthias Springer #include "mlir/IR/Iterators.h"
13695a5a6aSMatthias Springer #include "mlir/IR/RegionKindInterface.h"
14*5aeb604cSMaheshRavishankar #include "llvm/ADT/SmallPtrSet.h"
156fb3d597SDiego Caballero
167de0da95SChris Lattner using namespace mlir;
177de0da95SChris Lattner
18b99bd771SRiver Riddle //===----------------------------------------------------------------------===//
19b99bd771SRiver Riddle // PatternBenefit
20b99bd771SRiver Riddle //===----------------------------------------------------------------------===//
21b99bd771SRiver Riddle
PatternBenefit(unsigned benefit)227de0da95SChris Lattner PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
237de0da95SChris Lattner assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
247de0da95SChris Lattner "This pattern match benefit is too large to represent");
257de0da95SChris Lattner }
267de0da95SChris Lattner
getBenefit() const277de0da95SChris Lattner unsigned short PatternBenefit::getBenefit() const {
283e98fbf4SRiver Riddle assert(!isImpossibleToMatch() && "Pattern doesn't match");
297de0da95SChris Lattner return representation;
307de0da95SChris Lattner }
317de0da95SChris Lattner
327de0da95SChris Lattner //===----------------------------------------------------------------------===//
33b99bd771SRiver Riddle // Pattern
347de0da95SChris Lattner //===----------------------------------------------------------------------===//
357de0da95SChris Lattner
3676f3c2f3SRiver Riddle //===----------------------------------------------------------------------===//
3776f3c2f3SRiver Riddle // OperationName Root Constructors
3876f3c2f3SRiver Riddle
Pattern(StringRef rootName,PatternBenefit benefit,MLIRContext * context,ArrayRef<StringRef> generatedNames)3986a5323fSChris Lattner Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
4076f3c2f3SRiver Riddle MLIRContext *context, ArrayRef<StringRef> generatedNames)
4176f3c2f3SRiver Riddle : Pattern(OperationName(rootName, context).getAsOpaquePointer(),
4276f3c2f3SRiver Riddle RootKind::OperationName, generatedNames, benefit, context) {}
4376f3c2f3SRiver Riddle
4476f3c2f3SRiver Riddle //===----------------------------------------------------------------------===//
4576f3c2f3SRiver Riddle // MatchAnyOpTypeTag Root Constructors
4676f3c2f3SRiver Riddle
Pattern(MatchAnyOpTypeTag tag,PatternBenefit benefit,MLIRContext * context,ArrayRef<StringRef> generatedNames)4776f3c2f3SRiver Riddle Pattern::Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit,
4876f3c2f3SRiver Riddle MLIRContext *context, ArrayRef<StringRef> generatedNames)
4976f3c2f3SRiver Riddle : Pattern(nullptr, RootKind::Any, generatedNames, benefit, context) {}
5076f3c2f3SRiver Riddle
5176f3c2f3SRiver Riddle //===----------------------------------------------------------------------===//
5276f3c2f3SRiver Riddle // MatchInterfaceOpTypeTag Root Constructors
5376f3c2f3SRiver Riddle
Pattern(MatchInterfaceOpTypeTag tag,TypeID interfaceID,PatternBenefit benefit,MLIRContext * context,ArrayRef<StringRef> generatedNames)5476f3c2f3SRiver Riddle Pattern::Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID,
5576f3c2f3SRiver Riddle PatternBenefit benefit, MLIRContext *context,
5676f3c2f3SRiver Riddle ArrayRef<StringRef> generatedNames)
5776f3c2f3SRiver Riddle : Pattern(interfaceID.getAsOpaquePointer(), RootKind::InterfaceID,
5876f3c2f3SRiver Riddle generatedNames, benefit, context) {}
5976f3c2f3SRiver Riddle
6076f3c2f3SRiver Riddle //===----------------------------------------------------------------------===//
6176f3c2f3SRiver Riddle // MatchTraitOpTypeTag Root Constructors
6276f3c2f3SRiver Riddle
Pattern(MatchTraitOpTypeTag tag,TypeID traitID,PatternBenefit benefit,MLIRContext * context,ArrayRef<StringRef> generatedNames)6376f3c2f3SRiver Riddle Pattern::Pattern(MatchTraitOpTypeTag tag, TypeID traitID,
6476f3c2f3SRiver Riddle PatternBenefit benefit, MLIRContext *context,
6576f3c2f3SRiver Riddle ArrayRef<StringRef> generatedNames)
6676f3c2f3SRiver Riddle : Pattern(traitID.getAsOpaquePointer(), RootKind::TraitID, generatedNames,
6776f3c2f3SRiver Riddle benefit, context) {}
6876f3c2f3SRiver Riddle
6976f3c2f3SRiver Riddle //===----------------------------------------------------------------------===//
7076f3c2f3SRiver Riddle // General Constructors
7176f3c2f3SRiver Riddle
Pattern(const void * rootValue,RootKind rootKind,ArrayRef<StringRef> generatedNames,PatternBenefit benefit,MLIRContext * context)7276f3c2f3SRiver Riddle Pattern::Pattern(const void *rootValue, RootKind rootKind,
7376f3c2f3SRiver Riddle ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
7486a5323fSChris Lattner MLIRContext *context)
7576f3c2f3SRiver Riddle : rootValue(rootValue), rootKind(rootKind), benefit(benefit),
7676f3c2f3SRiver Riddle contextAndHasBoundedRecursion(context, false) {
7776f3c2f3SRiver Riddle if (generatedNames.empty())
7876f3c2f3SRiver Riddle return;
79b99bd771SRiver Riddle generatedOps.reserve(generatedNames.size());
80b99bd771SRiver Riddle std::transform(generatedNames.begin(), generatedNames.end(),
81b99bd771SRiver Riddle std::back_inserter(generatedOps), [context](StringRef name) {
82b99bd771SRiver Riddle return OperationName(name, context);
83b99bd771SRiver Riddle });
84b99bd771SRiver Riddle }
853f2530cdSChris Lattner
863f2530cdSChris Lattner //===----------------------------------------------------------------------===//
87b99bd771SRiver Riddle // RewritePattern
883f2530cdSChris Lattner //===----------------------------------------------------------------------===//
893f2530cdSChris Lattner
rewrite(Operation * op,PatternRewriter & rewriter) const90f9d91531SRiver Riddle void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
915de726f4SRiver Riddle llvm_unreachable("need to implement either matchAndRewrite or one of the "
925de726f4SRiver Riddle "rewrite functions!");
935de726f4SRiver Riddle }
945de726f4SRiver Riddle
match(Operation * op) const953145427dSRiver Riddle LogicalResult RewritePattern::match(Operation *op) const {
965de726f4SRiver Riddle llvm_unreachable("need to implement either match or matchAndRewrite!");
977de0da95SChris Lattner }
987de0da95SChris Lattner
99b99bd771SRiver Riddle /// Out-of-line vtable anchor.
anchor()100b99bd771SRiver Riddle void RewritePattern::anchor() {}
101b99bd771SRiver Riddle
102b99bd771SRiver Riddle //===----------------------------------------------------------------------===//
103ec10f066SRiver Riddle // RewriterBase
104b99bd771SRiver Riddle //===----------------------------------------------------------------------===//
105647f8cabSRiver Riddle
classof(const OpBuilder::Listener * base)106c6532830SMatthias Springer bool RewriterBase::Listener::classof(const OpBuilder::Listener *base) {
107c6532830SMatthias Springer return base->getKind() == OpBuilder::ListenerBase::Kind::RewriterBaseListener;
108c6532830SMatthias Springer }
109c6532830SMatthias Springer
~RewriterBase()110ec10f066SRiver Riddle RewriterBase::~RewriterBase() {
1117de0da95SChris Lattner // Out of line to provide a vtable anchor for the class.
1127de0da95SChris Lattner }
1137de0da95SChris Lattner
replaceAllOpUsesWith(Operation * from,ValueRange to)11438113a08SMatthias Springer void RewriterBase::replaceAllOpUsesWith(Operation *from, ValueRange to) {
11538113a08SMatthias Springer // Notify the listener that we're about to replace this op.
11638113a08SMatthias Springer if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
11738113a08SMatthias Springer rewriteListener->notifyOperationReplaced(from, to);
11838113a08SMatthias Springer
11938113a08SMatthias Springer replaceAllUsesWith(from->getResults(), to);
12038113a08SMatthias Springer }
12138113a08SMatthias Springer
replaceAllOpUsesWith(Operation * from,Operation * to)12238113a08SMatthias Springer void RewriterBase::replaceAllOpUsesWith(Operation *from, Operation *to) {
12338113a08SMatthias Springer // Notify the listener that we're about to replace this op.
12438113a08SMatthias Springer if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
12538113a08SMatthias Springer rewriteListener->notifyOperationReplaced(from, to);
12638113a08SMatthias Springer
12738113a08SMatthias Springer replaceAllUsesWith(from->getResults(), to->getResults());
12838113a08SMatthias Springer }
12938113a08SMatthias Springer
130ec10f066SRiver Riddle /// This method replaces the results of the operation with the specified list of
131ec10f066SRiver Riddle /// values. The number of provided values must match the number of results of
13271d50c89SMatthias Springer /// the operation. The replaced op is erased.
replaceOp(Operation * op,ValueRange newValues)133ec10f066SRiver Riddle void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
13421f4b84cSMatthias Springer assert(op->getNumResults() == newValues.size() &&
13521f4b84cSMatthias Springer "incorrect # of replacement values");
13621f4b84cSMatthias Springer
13759a92019SMatthias Springer // Replace all result uses. Also notifies the listener of modifications.
138f1aa7837SMatthias Springer replaceAllOpUsesWith(op, newValues);
1391427d0f0SChris Lattner
140695a5a6aSMatthias Springer // Erase op and notify listener.
14171d50c89SMatthias Springer eraseOp(op);
14271d50c89SMatthias Springer }
14371d50c89SMatthias Springer
14471d50c89SMatthias Springer /// This method replaces the results of the operation with the specified new op
14571d50c89SMatthias Springer /// (replacement). The number of results of the two operations must match. The
14671d50c89SMatthias Springer /// replaced op is erased.
replaceOp(Operation * op,Operation * newOp)14771d50c89SMatthias Springer void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
14871d50c89SMatthias Springer assert(op && newOp && "expected non-null op");
14971d50c89SMatthias Springer assert(op->getNumResults() == newOp->getNumResults() &&
15071d50c89SMatthias Springer "ops have different number of results");
15171d50c89SMatthias Springer
15259a92019SMatthias Springer // Replace all result uses. Also notifies the listener of modifications.
153f1aa7837SMatthias Springer replaceAllOpUsesWith(op, newOp->getResults());
15471d50c89SMatthias Springer
155695a5a6aSMatthias Springer // Erase op and notify listener.
15671d50c89SMatthias Springer eraseOp(op);
1571427d0f0SChris Lattner }
1581427d0f0SChris Lattner
/// This method erases an operation that is known to have no uses. The uses of
/// the given operation *must* be known to be dead.
///
/// Without an attached rewriter listener, the op (including everything nested
/// in its regions) is erased in one go. With a listener, the op and all
/// nested ops and blocks are erased one at a time: nested ops first, blocks in
/// post-order within each region, and ops in reverse order within each block.
/// Each `notifyOperationErased` / `notifyBlockErased` callback therefore
/// observes consistent IR.
void RewriterBase::eraseOp(Operation *op) {
  assert(op->use_empty() && "expected 'op' to have no uses");
  auto *rewriteListener = dyn_cast_if_present<Listener>(listener);

  // Fast path: If no listener is attached, the op can be dropped in one go.
  if (!rewriteListener) {
    op->erase();
    return;
  }

  // Helper function that erases a single op. Its regions must already be
  // empty. The listener is notified before the op is destroyed.
  auto eraseSingleOp = [&](Operation *op) {
#ifndef NDEBUG
    // All nested ops should have been erased already.
    assert(
        llvm::all_of(op->getRegions(), [&](Region &r) { return r.empty(); }) &&
        "expected empty regions");
    // All users should have been erased already if the op is in a region with
    // SSA dominance.
    if (!op->use_empty() && op->getParentOp())
      assert(mayBeGraphRegion(*op->getParentRegion()) &&
             "expected that op has no uses");
#endif // NDEBUG
    rewriteListener->notifyOperationErased(op);

    // Explicitly drop all uses in case the op is in a graph region.
    op->dropAllUses();
    op->erase();
  };

  // Nested ops must be erased one-by-one, so that listeners have a consistent
  // view of the IR every time a notification is triggered. Users must be
  // erased before definitions. I.e., post-order, reverse dominance.
  // (std::function is used because the lambda calls itself recursively.)
  std::function<void(Operation *)> eraseTree = [&](Operation *op) {
    // Erase nested ops.
    for (Region &r : llvm::reverse(op->getRegions())) {
      // Erase all blocks in the right order. Successors should be erased
      // before predecessors because successor blocks may use values defined
      // in predecessor blocks. A post-order traversal of blocks within a
      // region visits successors before predecessors. Repeat the traversal
      // until the region is empty. (The block graph could be disconnected.)
      while (!r.empty()) {
        SmallVector<Block *> erasedBlocks;
        // Some blocks may have a null successor. Seeding the visited set with
        // nullptr keeps the traversal from ever descending into it.
        llvm::SmallPtrSet<Block *, 4> visited{nullptr};
        for (Block *b : llvm::post_order_ext(&r.front(), visited)) {
          // Visit ops in reverse order.
          for (Operation &op :
               llvm::make_early_inc_range(ReverseIterator::makeIterable(*b)))
            eraseTree(&op);
          // Do not erase the block immediately. This is not supported by the
          // post_order iterator.
          erasedBlocks.push_back(b);
        }
        for (Block *b : erasedBlocks) {
          // Explicitly drop all uses in case there is a cycle in the block
          // graph.
          for (BlockArgument bbArg : b->getArguments())
            bbArg.dropAllUses();
          b->dropAllUses();
          // The block is empty at this point; eraseBlock notifies the
          // listener and removes it.
          eraseBlock(b);
        }
      }
    }
    // Then erase the enclosing op.
    eraseSingleOp(op);
  };

  eraseTree(op);
}
232dfe09cc6SRiver Riddle
eraseBlock(Block * block)233ec10f066SRiver Riddle void RewriterBase::eraseBlock(Block *block) {
23462bf7710SMatthias Springer assert(block->use_empty() && "expected 'block' to have no uses");
23562bf7710SMatthias Springer
2363f9cdd44SUday Bondhugula for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
2373f9cdd44SUday Bondhugula assert(op.use_empty() && "expected 'op' to have no uses");
2383f9cdd44SUday Bondhugula eraseOp(&op);
2393f9cdd44SUday Bondhugula }
24062bf7710SMatthias Springer
24162bf7710SMatthias Springer // Notify the listener that the block is about to be removed.
24262bf7710SMatthias Springer if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
243914e6074SMatthias Springer rewriteListener->notifyBlockErased(block);
24462bf7710SMatthias Springer
2453f9cdd44SUday Bondhugula block->erase();
2463f9cdd44SUday Bondhugula }
2473f9cdd44SUday Bondhugula
finalizeOpModification(Operation * op)2485fcf907bSMatthias Springer void RewriterBase::finalizeOpModification(Operation *op) {
249bafc4dfcSMatthias Springer // Notify the listener that the operation was modified.
250bafc4dfcSMatthias Springer if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
251bafc4dfcSMatthias Springer rewriteListener->notifyOperationModified(op);
252bafc4dfcSMatthias Springer }
253bafc4dfcSMatthias Springer
replaceAllUsesExcept(Value from,Value to,const SmallPtrSetImpl<Operation * > & preservedUsers)254*5aeb604cSMaheshRavishankar void RewriterBase::replaceAllUsesExcept(
255*5aeb604cSMaheshRavishankar Value from, Value to, const SmallPtrSetImpl<Operation *> &preservedUsers) {
256*5aeb604cSMaheshRavishankar return replaceUsesWithIf(from, to, [&](OpOperand &use) {
257*5aeb604cSMaheshRavishankar Operation *user = use.getOwner();
258*5aeb604cSMaheshRavishankar return !preservedUsers.contains(user);
259*5aeb604cSMaheshRavishankar });
260*5aeb604cSMaheshRavishankar }
261*5aeb604cSMaheshRavishankar
replaceUsesWithIf(Value from,Value to,function_ref<bool (OpOperand &)> functor,bool * allUsesReplaced)26221f4b84cSMatthias Springer void RewriterBase::replaceUsesWithIf(Value from, Value to,
26359a92019SMatthias Springer function_ref<bool(OpOperand &)> functor,
26459a92019SMatthias Springer bool *allUsesReplaced) {
26559a92019SMatthias Springer bool allReplaced = true;
26677603e28SDiego Caballero for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
26759a92019SMatthias Springer bool replace = functor(operand);
26859a92019SMatthias Springer if (replace)
2695fcf907bSMatthias Springer modifyOpInPlace(operand.getOwner(), [&]() { operand.set(to); });
27059a92019SMatthias Springer allReplaced &= replace;
27177603e28SDiego Caballero }
27259a92019SMatthias Springer if (allUsesReplaced)
27359a92019SMatthias Springer *allUsesReplaced = allReplaced;
27459a92019SMatthias Springer }
27559a92019SMatthias Springer
replaceUsesWithIf(ValueRange from,ValueRange to,function_ref<bool (OpOperand &)> functor,bool * allUsesReplaced)27659a92019SMatthias Springer void RewriterBase::replaceUsesWithIf(ValueRange from, ValueRange to,
27759a92019SMatthias Springer function_ref<bool(OpOperand &)> functor,
27859a92019SMatthias Springer bool *allUsesReplaced) {
27959a92019SMatthias Springer assert(from.size() == to.size() && "incorrect number of replacements");
28059a92019SMatthias Springer bool allReplaced = true;
28159a92019SMatthias Springer for (auto it : llvm::zip_equal(from, to)) {
28259a92019SMatthias Springer bool r;
28359a92019SMatthias Springer replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor,
28459a92019SMatthias Springer /*allUsesReplaced=*/&r);
28559a92019SMatthias Springer allReplaced &= r;
28659a92019SMatthias Springer }
28759a92019SMatthias Springer if (allUsesReplaced)
28859a92019SMatthias Springer *allUsesReplaced = allReplaced;
28977603e28SDiego Caballero }
29077603e28SDiego Caballero
inlineBlockBefore(Block * source,Block * dest,Block::iterator before,ValueRange argValues)29142c31d83SMatthias Springer void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
29242c31d83SMatthias Springer Block::iterator before,
2939c7b0c4aSRahul Joshi ValueRange argValues) {
29442c31d83SMatthias Springer assert(argValues.size() == source->getNumArguments() &&
29542c31d83SMatthias Springer "incorrect # of argument replacement values");
29642c31d83SMatthias Springer
29742c31d83SMatthias Springer // The source block will be deleted, so it should not have any users (i.e.,
29842c31d83SMatthias Springer // there should be no predecessors).
2999c7b0c4aSRahul Joshi assert(source->hasNoPredecessors() &&
3009c7b0c4aSRahul Joshi "expected 'source' to have no predecessors");
30142c31d83SMatthias Springer
30242c31d83SMatthias Springer if (dest->end() != before) {
30342c31d83SMatthias Springer // The source block will be inserted in the middle of the dest block, so
30442c31d83SMatthias Springer // the source block should have no successors. Otherwise, the remainder of
30542c31d83SMatthias Springer // the dest block would be unreachable.
3069c7b0c4aSRahul Joshi assert(source->hasNoSuccessors() &&
3079c7b0c4aSRahul Joshi "expected 'source' to have no successors");
30842c31d83SMatthias Springer } else {
30942c31d83SMatthias Springer // The source block will be inserted at the end of the dest block, so the
31042c31d83SMatthias Springer // dest block should have no successors. Otherwise, the inserted operations
31142c31d83SMatthias Springer // will be unreachable.
31242c31d83SMatthias Springer assert(dest->hasNoSuccessors() && "expected 'dest' to have no successors");
31342c31d83SMatthias Springer }
3149c7b0c4aSRahul Joshi
31542c31d83SMatthias Springer // Replace all of the successor arguments with the provided values.
31642c31d83SMatthias Springer for (auto it : llvm::zip(source->getArguments(), argValues))
31742c31d83SMatthias Springer replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
3189c7b0c4aSRahul Joshi
31942c31d83SMatthias Springer // Move operations from the source block to the dest block and erase the
32042c31d83SMatthias Springer // source block.
321c672b342SMatthias Springer if (!listener) {
322c672b342SMatthias Springer // Fast path: If no listener is attached, move all operations at once.
32342c31d83SMatthias Springer dest->getOperations().splice(before, source->getOperations());
324c672b342SMatthias Springer } else {
325c672b342SMatthias Springer while (!source->empty())
326c672b342SMatthias Springer moveOpBefore(&source->front(), dest, before);
327c672b342SMatthias Springer }
328c672b342SMatthias Springer
329c672b342SMatthias Springer // Erase the source block.
330c672b342SMatthias Springer assert(source->empty() && "expected 'source' to be empty");
33162bf7710SMatthias Springer eraseBlock(source);
33242c31d83SMatthias Springer }
3339c7b0c4aSRahul Joshi
inlineBlockBefore(Block * source,Operation * op,ValueRange argValues)33442c31d83SMatthias Springer void RewriterBase::inlineBlockBefore(Block *source, Operation *op,
33542c31d83SMatthias Springer ValueRange argValues) {
33642c31d83SMatthias Springer inlineBlockBefore(source, op->getBlock(), op->getIterator(), argValues);
33742c31d83SMatthias Springer }
33842c31d83SMatthias Springer
mergeBlocks(Block * source,Block * dest,ValueRange argValues)33942c31d83SMatthias Springer void RewriterBase::mergeBlocks(Block *source, Block *dest,
34042c31d83SMatthias Springer ValueRange argValues) {
34142c31d83SMatthias Springer inlineBlockBefore(source, dest, dest->end(), argValues);
3429c7b0c4aSRahul Joshi }
3439c7b0c4aSRahul Joshi
3442366561aSRiver Riddle /// Split the operations starting at "before" (inclusive) out of the given
3452366561aSRiver Riddle /// block into a new block, and return it.
splitBlock(Block * block,Block::iterator before)346ec10f066SRiver Riddle Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
347c2675ba9SMatthias Springer // Fast path: If no listener is attached, split the block directly.
348c2675ba9SMatthias Springer if (!listener)
3492366561aSRiver Riddle return block->splitBlock(before);
350c2675ba9SMatthias Springer
351c2675ba9SMatthias Springer // `createBlock` sets the insertion point at the beginning of the new block.
352c2675ba9SMatthias Springer InsertionGuard g(*this);
353c2675ba9SMatthias Springer Block *newBlock =
354c2675ba9SMatthias Springer createBlock(block->getParent(), std::next(block->getIterator()));
355c2675ba9SMatthias Springer
356c2675ba9SMatthias Springer // If `before` points to end of the block, no ops should be moved.
357c2675ba9SMatthias Springer if (before == block->end())
358c2675ba9SMatthias Springer return newBlock;
359c2675ba9SMatthias Springer
360c2675ba9SMatthias Springer // Move ops one-by-one from the end of `block` to the beginning of `newBlock`.
361c2675ba9SMatthias Springer // Stop when the operation pointed to by `before` has been moved.
362c2675ba9SMatthias Springer while (before->getBlock() != newBlock)
363c2675ba9SMatthias Springer moveOpBefore(&block->back(), newBlock, newBlock->begin());
364c2675ba9SMatthias Springer
365c2675ba9SMatthias Springer return newBlock;
3662366561aSRiver Riddle }
3672366561aSRiver Riddle
3688ad35b90SAlex Zinenko /// Move the blocks that belong to "region" before the given position in
3698ad35b90SAlex Zinenko /// another region. The two regions must be different. The caller is in
3708ad35b90SAlex Zinenko /// charge to update create the operation transferring the control flow to the
3718ad35b90SAlex Zinenko /// region and pass it the correct block arguments.
inlineRegionBefore(Region & region,Region & parent,Region::iterator before)372ec10f066SRiver Riddle void RewriterBase::inlineRegionBefore(Region ®ion, Region &parent,
3738ad35b90SAlex Zinenko Region::iterator before) {
3743ed98cb3SMatthias Springer // Fast path: If no listener is attached, move all blocks at once.
3753ed98cb3SMatthias Springer if (!listener) {
3763e99d995SRiver Riddle parent.getBlocks().splice(before, region.getBlocks());
3773ed98cb3SMatthias Springer return;
3783ed98cb3SMatthias Springer }
3793ed98cb3SMatthias Springer
3803ed98cb3SMatthias Springer // Move blocks from the beginning of the region one-by-one.
381da784a25SMatthias Springer while (!region.empty())
382da784a25SMatthias Springer moveBlockBefore(®ion.front(), &parent, before);
3833e99d995SRiver Riddle }
inlineRegionBefore(Region & region,Block * before)384ec10f066SRiver Riddle void RewriterBase::inlineRegionBefore(Region ®ion, Block *before) {
3853e99d995SRiver Riddle inlineRegionBefore(region, *before->getParent(), before->getIterator());
3868ad35b90SAlex Zinenko }
3878ad35b90SAlex Zinenko
moveBlockBefore(Block * block,Block * anotherBlock)388da784a25SMatthias Springer void RewriterBase::moveBlockBefore(Block *block, Block *anotherBlock) {
389da784a25SMatthias Springer moveBlockBefore(block, anotherBlock->getParent(),
390da784a25SMatthias Springer anotherBlock->getIterator());
391da784a25SMatthias Springer }
392da784a25SMatthias Springer
moveBlockBefore(Block * block,Region * region,Region::iterator iterator)393da784a25SMatthias Springer void RewriterBase::moveBlockBefore(Block *block, Region *region,
394da784a25SMatthias Springer Region::iterator iterator) {
395da784a25SMatthias Springer Region *currentRegion = block->getParent();
396da784a25SMatthias Springer Region::iterator nextIterator = std::next(block->getIterator());
397da784a25SMatthias Springer block->moveBefore(region, iterator);
398da784a25SMatthias Springer if (listener)
399da784a25SMatthias Springer listener->notifyBlockInserted(block, /*previous=*/currentRegion,
400da784a25SMatthias Springer /*previousIt=*/nextIterator);
401da784a25SMatthias Springer }
402da784a25SMatthias Springer
moveOpBefore(Operation * op,Operation * existingOp)4035cc0f76dSMatthias Springer void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
4045cc0f76dSMatthias Springer moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
4055cc0f76dSMatthias Springer }
4065cc0f76dSMatthias Springer
moveOpBefore(Operation * op,Block * block,Block::iterator iterator)4075cc0f76dSMatthias Springer void RewriterBase::moveOpBefore(Operation *op, Block *block,
4085cc0f76dSMatthias Springer Block::iterator iterator) {
4095cc0f76dSMatthias Springer Block *currentBlock = op->getBlock();
410da784a25SMatthias Springer Block::iterator nextIterator = std::next(op->getIterator());
4115cc0f76dSMatthias Springer op->moveBefore(block, iterator);
4125cc0f76dSMatthias Springer if (listener)
4135cc0f76dSMatthias Springer listener->notifyOperationInserted(
414da784a25SMatthias Springer op, /*previous=*/InsertPoint(currentBlock, nextIterator));
4155cc0f76dSMatthias Springer }
4165cc0f76dSMatthias Springer
moveOpAfter(Operation * op,Operation * existingOp)4175cc0f76dSMatthias Springer void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
4185cc0f76dSMatthias Springer moveOpAfter(op, existingOp->getBlock(), existingOp->getIterator());
4195cc0f76dSMatthias Springer }
4205cc0f76dSMatthias Springer
moveOpAfter(Operation * op,Block * block,Block::iterator iterator)4215cc0f76dSMatthias Springer void RewriterBase::moveOpAfter(Operation *op, Block *block,
4225cc0f76dSMatthias Springer Block::iterator iterator) {
423da784a25SMatthias Springer assert(iterator != block->end() && "cannot move after end of block");
424da784a25SMatthias Springer moveOpBefore(op, block, std::next(iterator));
4255cc0f76dSMatthias Springer }
426