xref: /llvm-project/mlir/lib/Transforms/Utils/RegionUtils.cpp (revision 1e5f32e81f96af45551dafb369279c6d55ac9b97)
1ee6f84aeSAlex Zinenko //===- RegionUtils.cpp - Region-related transformation utilities ----------===//
2ee6f84aeSAlex Zinenko //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6ee6f84aeSAlex Zinenko //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
8ee6f84aeSAlex Zinenko 
9ee6f84aeSAlex Zinenko #include "mlir/Transforms/RegionUtils.h"
10b00e0c16SChristian Ulmann #include "mlir/Analysis/TopologicalSortUtils.h"
11ee6f84aeSAlex Zinenko #include "mlir/IR/Block.h"
12441b672bSGiuseppe Rossini #include "mlir/IR/BuiltinOps.h"
13da784e77SMahesh Ravishankar #include "mlir/IR/IRMapping.h"
14ee6f84aeSAlex Zinenko #include "mlir/IR/Operation.h"
15d75a611aSRiver Riddle #include "mlir/IR/PatternMatch.h"
16fafb708bSRiver Riddle #include "mlir/IR/RegionGraphTraits.h"
17ee6f84aeSAlex Zinenko #include "mlir/IR/Value.h"
187ce1e7abSRiver Riddle #include "mlir/Interfaces/ControlFlowInterfaces.h"
19eb623ae8SStephen Neuendorffer #include "mlir/Interfaces/SideEffectInterfaces.h"
20441b672bSGiuseppe Rossini #include "mlir/Support/LogicalResult.h"
21ee6f84aeSAlex Zinenko 
22fafb708bSRiver Riddle #include "llvm/ADT/DepthFirstIterator.h"
23fafb708bSRiver Riddle #include "llvm/ADT/PostOrderIterator.h"
24441b672bSGiuseppe Rossini #include "llvm/ADT/STLExtras.h"
25441b672bSGiuseppe Rossini #include "llvm/ADT/SmallSet.h"
264291ae74SAlex Zinenko 
27da784e77SMahesh Ravishankar #include <deque>
28441b672bSGiuseppe Rossini #include <iterator>
29da784e77SMahesh Ravishankar 
30ee6f84aeSAlex Zinenko using namespace mlir;
31ee6f84aeSAlex Zinenko 
32e62a6956SRiver Riddle void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement,
33ee6f84aeSAlex Zinenko                                       Region &region) {
342bdf33ccSRiver Riddle   for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
351e429540SRiver Riddle     if (region.isAncestor(use.getOwner()->getParentRegion()))
36ee6f84aeSAlex Zinenko       use.set(replacement);
37ee6f84aeSAlex Zinenko   }
38ee6f84aeSAlex Zinenko }
394291ae74SAlex Zinenko 
406443583bSMehdi Amini void mlir::visitUsedValuesDefinedAbove(
414562e389SRiver Riddle     Region &region, Region &limit, function_ref<void(OpOperand *)> callback) {
424291ae74SAlex Zinenko   assert(limit.isAncestor(&region) &&
434291ae74SAlex Zinenko          "expected isolation limit to be an ancestor of the given region");
444291ae74SAlex Zinenko 
454291ae74SAlex Zinenko   // Collect proper ancestors of `limit` upfront to avoid traversing the region
464291ae74SAlex Zinenko   // tree for every value.
474562e389SRiver Riddle   SmallPtrSet<Region *, 4> properAncestors;
481e429540SRiver Riddle   for (auto *reg = limit.getParentRegion(); reg != nullptr;
491e429540SRiver Riddle        reg = reg->getParentRegion()) {
504291ae74SAlex Zinenko     properAncestors.insert(reg);
514291ae74SAlex Zinenko   }
524291ae74SAlex Zinenko 
536443583bSMehdi Amini   region.walk([callback, &properAncestors](Operation *op) {
546443583bSMehdi Amini     for (OpOperand &operand : op->getOpOperands())
556443583bSMehdi Amini       // Callback on values defined in a proper ancestor of region.
562bdf33ccSRiver Riddle       if (properAncestors.count(operand.get().getParentRegion()))
576443583bSMehdi Amini         callback(&operand);
586443583bSMehdi Amini   });
596443583bSMehdi Amini }
606443583bSMehdi Amini 
616443583bSMehdi Amini void mlir::visitUsedValuesDefinedAbove(
624562e389SRiver Riddle     MutableArrayRef<Region> regions, function_ref<void(OpOperand *)> callback) {
636443583bSMehdi Amini   for (Region &region : regions)
646443583bSMehdi Amini     visitUsedValuesDefinedAbove(region, region, callback);
656443583bSMehdi Amini }
666443583bSMehdi Amini 
676443583bSMehdi Amini void mlir::getUsedValuesDefinedAbove(Region &region, Region &limit,
684efb7754SRiver Riddle                                      SetVector<Value> &values) {
696443583bSMehdi Amini   visitUsedValuesDefinedAbove(region, limit, [&](OpOperand *operand) {
706443583bSMehdi Amini     values.insert(operand->get());
714291ae74SAlex Zinenko   });
724291ae74SAlex Zinenko }
73ce702fc8SMehdi Amini 
744562e389SRiver Riddle void mlir::getUsedValuesDefinedAbove(MutableArrayRef<Region> regions,
754efb7754SRiver Riddle                                      SetVector<Value> &values) {
76ce702fc8SMehdi Amini   for (Region &region : regions)
77ce702fc8SMehdi Amini     getUsedValuesDefinedAbove(region, region, values);
78ce702fc8SMehdi Amini }
79fafb708bSRiver Riddle 
80fafb708bSRiver Riddle //===----------------------------------------------------------------------===//
81da784e77SMahesh Ravishankar // Make block isolated from above.
82da784e77SMahesh Ravishankar //===----------------------------------------------------------------------===//
83da784e77SMahesh Ravishankar 
84da784e77SMahesh Ravishankar SmallVector<Value> mlir::makeRegionIsolatedFromAbove(
85da784e77SMahesh Ravishankar     RewriterBase &rewriter, Region &region,
86da784e77SMahesh Ravishankar     llvm::function_ref<bool(Operation *)> cloneOperationIntoRegion) {
87da784e77SMahesh Ravishankar 
88da784e77SMahesh Ravishankar   // Get initial list of values used within region but defined above.
89da784e77SMahesh Ravishankar   llvm::SetVector<Value> initialCapturedValues;
90da784e77SMahesh Ravishankar   mlir::getUsedValuesDefinedAbove(region, initialCapturedValues);
91da784e77SMahesh Ravishankar 
92da784e77SMahesh Ravishankar   std::deque<Value> worklist(initialCapturedValues.begin(),
93da784e77SMahesh Ravishankar                              initialCapturedValues.end());
94da784e77SMahesh Ravishankar   llvm::DenseSet<Value> visited;
95da784e77SMahesh Ravishankar   llvm::DenseSet<Operation *> visitedOps;
96da784e77SMahesh Ravishankar 
97da784e77SMahesh Ravishankar   llvm::SetVector<Value> finalCapturedValues;
98da784e77SMahesh Ravishankar   SmallVector<Operation *> clonedOperations;
99da784e77SMahesh Ravishankar   while (!worklist.empty()) {
100da784e77SMahesh Ravishankar     Value currValue = worklist.front();
101da784e77SMahesh Ravishankar     worklist.pop_front();
102da784e77SMahesh Ravishankar     if (visited.count(currValue))
103da784e77SMahesh Ravishankar       continue;
104da784e77SMahesh Ravishankar     visited.insert(currValue);
105da784e77SMahesh Ravishankar 
106da784e77SMahesh Ravishankar     Operation *definingOp = currValue.getDefiningOp();
107da784e77SMahesh Ravishankar     if (!definingOp || visitedOps.count(definingOp)) {
108da784e77SMahesh Ravishankar       finalCapturedValues.insert(currValue);
109da784e77SMahesh Ravishankar       continue;
110da784e77SMahesh Ravishankar     }
111da784e77SMahesh Ravishankar     visitedOps.insert(definingOp);
112da784e77SMahesh Ravishankar 
113da784e77SMahesh Ravishankar     if (!cloneOperationIntoRegion(definingOp)) {
114da784e77SMahesh Ravishankar       // Defining operation isnt cloned, so add the current value to final
115da784e77SMahesh Ravishankar       // captured values list.
116da784e77SMahesh Ravishankar       finalCapturedValues.insert(currValue);
117da784e77SMahesh Ravishankar       continue;
118da784e77SMahesh Ravishankar     }
119da784e77SMahesh Ravishankar 
120da784e77SMahesh Ravishankar     // Add all operands of the operation to the worklist and mark the op as to
121da784e77SMahesh Ravishankar     // be cloned.
122da784e77SMahesh Ravishankar     for (Value operand : definingOp->getOperands()) {
123da784e77SMahesh Ravishankar       if (visited.count(operand))
124da784e77SMahesh Ravishankar         continue;
125da784e77SMahesh Ravishankar       worklist.push_back(operand);
126da784e77SMahesh Ravishankar     }
127da784e77SMahesh Ravishankar     clonedOperations.push_back(definingOp);
128da784e77SMahesh Ravishankar   }
129da784e77SMahesh Ravishankar 
130da784e77SMahesh Ravishankar   // The operations to be cloned need to be ordered in topological order
131da784e77SMahesh Ravishankar   // so that they can be cloned into the region without violating use-def
132da784e77SMahesh Ravishankar   // chains.
133da784e77SMahesh Ravishankar   mlir::computeTopologicalSorting(clonedOperations);
134da784e77SMahesh Ravishankar 
135da784e77SMahesh Ravishankar   OpBuilder::InsertionGuard g(rewriter);
136da784e77SMahesh Ravishankar   // Collect types of existing block
137da784e77SMahesh Ravishankar   Block *entryBlock = &region.front();
138da784e77SMahesh Ravishankar   SmallVector<Type> newArgTypes =
139da784e77SMahesh Ravishankar       llvm::to_vector(entryBlock->getArgumentTypes());
140da784e77SMahesh Ravishankar   SmallVector<Location> newArgLocs = llvm::to_vector(llvm::map_range(
141da784e77SMahesh Ravishankar       entryBlock->getArguments(), [](BlockArgument b) { return b.getLoc(); }));
142da784e77SMahesh Ravishankar 
143da784e77SMahesh Ravishankar   // Append the types of the captured values.
144da784e77SMahesh Ravishankar   for (auto value : finalCapturedValues) {
145da784e77SMahesh Ravishankar     newArgTypes.push_back(value.getType());
146da784e77SMahesh Ravishankar     newArgLocs.push_back(value.getLoc());
147da784e77SMahesh Ravishankar   }
148da784e77SMahesh Ravishankar 
149da784e77SMahesh Ravishankar   // Create a new entry block.
150da784e77SMahesh Ravishankar   Block *newEntryBlock =
151da784e77SMahesh Ravishankar       rewriter.createBlock(&region, region.begin(), newArgTypes, newArgLocs);
152da784e77SMahesh Ravishankar   auto newEntryBlockArgs = newEntryBlock->getArguments();
153da784e77SMahesh Ravishankar 
154da784e77SMahesh Ravishankar   // Create a mapping between the captured values and the new arguments added.
155da784e77SMahesh Ravishankar   IRMapping map;
156da784e77SMahesh Ravishankar   auto replaceIfFn = [&](OpOperand &use) {
157da784e77SMahesh Ravishankar     return use.getOwner()->getBlock()->getParent() == &region;
158da784e77SMahesh Ravishankar   };
159da784e77SMahesh Ravishankar   for (auto [arg, capturedVal] :
160da784e77SMahesh Ravishankar        llvm::zip(newEntryBlockArgs.take_back(finalCapturedValues.size()),
161da784e77SMahesh Ravishankar                  finalCapturedValues)) {
162da784e77SMahesh Ravishankar     map.map(capturedVal, arg);
163da784e77SMahesh Ravishankar     rewriter.replaceUsesWithIf(capturedVal, arg, replaceIfFn);
164da784e77SMahesh Ravishankar   }
165da784e77SMahesh Ravishankar   rewriter.setInsertionPointToStart(newEntryBlock);
166d1cafe2dSMehdi Amini   for (auto *clonedOp : clonedOperations) {
167da784e77SMahesh Ravishankar     Operation *newOp = rewriter.clone(*clonedOp, map);
168f1aa7837SMatthias Springer     rewriter.replaceOpUsesWithIf(clonedOp, newOp->getResults(), replaceIfFn);
169da784e77SMahesh Ravishankar   }
170da784e77SMahesh Ravishankar   rewriter.mergeBlocks(
171da784e77SMahesh Ravishankar       entryBlock, newEntryBlock,
172da784e77SMahesh Ravishankar       newEntryBlock->getArguments().take_front(entryBlock->getNumArguments()));
173da784e77SMahesh Ravishankar   return llvm::to_vector(finalCapturedValues);
174da784e77SMahesh Ravishankar }
175da784e77SMahesh Ravishankar 
176da784e77SMahesh Ravishankar //===----------------------------------------------------------------------===//
177fafb708bSRiver Riddle // Unreachable Block Elimination
178fafb708bSRiver Riddle //===----------------------------------------------------------------------===//
179fafb708bSRiver Riddle 
180fafb708bSRiver Riddle /// Erase the unreachable blocks within the provided regions. Returns success
181fafb708bSRiver Riddle /// if any blocks were erased, failure otherwise.
182fafb708bSRiver Riddle // TODO: We could likely merge this with the DCE algorithm below.
18378d69182SValentin Clement LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter,
184d75a611aSRiver Riddle                                            MutableArrayRef<Region> regions) {
185fafb708bSRiver Riddle   // Set of blocks found to be reachable within a given region.
186fafb708bSRiver Riddle   llvm::df_iterator_default_set<Block *, 16> reachable;
187fafb708bSRiver Riddle   // If any blocks were found to be dead.
188fafb708bSRiver Riddle   bool erasedDeadBlocks = false;
189fafb708bSRiver Riddle 
190fafb708bSRiver Riddle   SmallVector<Region *, 1> worklist;
191fafb708bSRiver Riddle   worklist.reserve(regions.size());
192fafb708bSRiver Riddle   for (Region &region : regions)
193fafb708bSRiver Riddle     worklist.push_back(&region);
194fafb708bSRiver Riddle   while (!worklist.empty()) {
195fafb708bSRiver Riddle     Region *region = worklist.pop_back_val();
196fafb708bSRiver Riddle     if (region->empty())
197fafb708bSRiver Riddle       continue;
198fafb708bSRiver Riddle 
199fafb708bSRiver Riddle     // If this is a single block region, just collect the nested regions.
200fafb708bSRiver Riddle     if (std::next(region->begin()) == region->end()) {
201fafb708bSRiver Riddle       for (Operation &op : region->front())
202fafb708bSRiver Riddle         for (Region &region : op.getRegions())
203fafb708bSRiver Riddle           worklist.push_back(&region);
204fafb708bSRiver Riddle       continue;
205fafb708bSRiver Riddle     }
206fafb708bSRiver Riddle 
207fafb708bSRiver Riddle     // Mark all reachable blocks.
208fafb708bSRiver Riddle     reachable.clear();
209fafb708bSRiver Riddle     for (Block *block : depth_first_ext(&region->front(), reachable))
210fafb708bSRiver Riddle       (void)block /* Mark all reachable blocks */;
211fafb708bSRiver Riddle 
212fafb708bSRiver Riddle     // Collect all of the dead blocks and push the live regions onto the
213fafb708bSRiver Riddle     // worklist.
214fafb708bSRiver Riddle     for (Block &block : llvm::make_early_inc_range(*region)) {
215fafb708bSRiver Riddle       if (!reachable.count(&block)) {
216fafb708bSRiver Riddle         block.dropAllDefinedValueUses();
217d75a611aSRiver Riddle         rewriter.eraseBlock(&block);
218fafb708bSRiver Riddle         erasedDeadBlocks = true;
219fafb708bSRiver Riddle         continue;
220fafb708bSRiver Riddle       }
221fafb708bSRiver Riddle 
222fafb708bSRiver Riddle       // Walk any regions within this block.
223fafb708bSRiver Riddle       for (Operation &op : block)
224fafb708bSRiver Riddle         for (Region &region : op.getRegions())
225fafb708bSRiver Riddle           worklist.push_back(&region);
226fafb708bSRiver Riddle     }
227fafb708bSRiver Riddle   }
228fafb708bSRiver Riddle 
229fafb708bSRiver Riddle   return success(erasedDeadBlocks);
230fafb708bSRiver Riddle }
231fafb708bSRiver Riddle 
232fafb708bSRiver Riddle //===----------------------------------------------------------------------===//
233fafb708bSRiver Riddle // Dead Code Elimination
234fafb708bSRiver Riddle //===----------------------------------------------------------------------===//
235fafb708bSRiver Riddle 
236fafb708bSRiver Riddle namespace {
237fafb708bSRiver Riddle /// Data structure used to track which values have already been proved live.
238fafb708bSRiver Riddle ///
239fafb708bSRiver Riddle /// Because Operation's can have multiple results, this data structure tracks
240fafb708bSRiver Riddle /// liveness for both Value's and Operation's to avoid having to look through
241fafb708bSRiver Riddle /// all Operation results when analyzing a use.
242fafb708bSRiver Riddle ///
243fafb708bSRiver Riddle /// This data structure essentially tracks the dataflow lattice.
244fafb708bSRiver Riddle /// The set of values/ops proved live increases monotonically to a fixed-point.
245fafb708bSRiver Riddle class LiveMap {
246fafb708bSRiver Riddle public:
247fafb708bSRiver Riddle   /// Value methods.
2484e02eb80SRiver Riddle   bool wasProvenLive(Value value) {
2494e02eb80SRiver Riddle     // TODO: For results that are removable, e.g. for region based control flow,
2504e02eb80SRiver Riddle     // we could allow for these values to be tracked independently.
2515550c821STres Popp     if (OpResult result = dyn_cast<OpResult>(value))
2524e02eb80SRiver Riddle       return wasProvenLive(result.getOwner());
2535550c821STres Popp     return wasProvenLive(cast<BlockArgument>(value));
2544e02eb80SRiver Riddle   }
2554e02eb80SRiver Riddle   bool wasProvenLive(BlockArgument arg) { return liveValues.count(arg); }
256e62a6956SRiver Riddle   void setProvedLive(Value value) {
2574e02eb80SRiver Riddle     // TODO: For results that are removable, e.g. for region based control flow,
2584e02eb80SRiver Riddle     // we could allow for these values to be tracked independently.
2595550c821STres Popp     if (OpResult result = dyn_cast<OpResult>(value))
2604e02eb80SRiver Riddle       return setProvedLive(result.getOwner());
2615550c821STres Popp     setProvedLive(cast<BlockArgument>(value));
2624e02eb80SRiver Riddle   }
2634e02eb80SRiver Riddle   void setProvedLive(BlockArgument arg) {
2644e02eb80SRiver Riddle     changed |= liveValues.insert(arg).second;
265fafb708bSRiver Riddle   }
266fafb708bSRiver Riddle 
267fafb708bSRiver Riddle   /// Operation methods.
268fafb708bSRiver Riddle   bool wasProvenLive(Operation *op) { return liveOps.count(op); }
269fafb708bSRiver Riddle   void setProvedLive(Operation *op) { changed |= liveOps.insert(op).second; }
270fafb708bSRiver Riddle 
271fafb708bSRiver Riddle   /// Methods for tracking if we have reached a fixed-point.
272fafb708bSRiver Riddle   void resetChanged() { changed = false; }
273fafb708bSRiver Riddle   bool hasChanged() { return changed; }
274fafb708bSRiver Riddle 
275fafb708bSRiver Riddle private:
276fafb708bSRiver Riddle   bool changed = false;
277e62a6956SRiver Riddle   DenseSet<Value> liveValues;
278fafb708bSRiver Riddle   DenseSet<Operation *> liveOps;
279fafb708bSRiver Riddle };
280fafb708bSRiver Riddle } // namespace
281fafb708bSRiver Riddle 
282fafb708bSRiver Riddle static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap) {
283fafb708bSRiver Riddle   Operation *owner = use.getOwner();
284fafb708bSRiver Riddle   unsigned operandIndex = use.getOperandNumber();
285fafb708bSRiver Riddle   // This pass generally treats all uses of an op as live if the op itself is
286fafb708bSRiver Riddle   // considered live. However, for successor operands to terminators we need a
287fafb708bSRiver Riddle   // finer-grained notion where we deduce liveness for operands individually.
288fafb708bSRiver Riddle   // The reason for this is easiest to think about in terms of a classical phi
289fafb708bSRiver Riddle   // node based SSA IR, where each successor operand is really an operand to a
290fafb708bSRiver Riddle   // *separate* phi node, rather than all operands to the branch itself as with
291fafb708bSRiver Riddle   // the block argument representation that MLIR uses.
292fafb708bSRiver Riddle   //
293fafb708bSRiver Riddle   // And similarly, because each successor operand is really an operand to a phi
294fafb708bSRiver Riddle   // node, rather than to the terminator op itself, a terminator op can't e.g.
295fafb708bSRiver Riddle   // "print" the value of a successor operand.
296fe7c0d90SRiver Riddle   if (owner->hasTrait<OpTrait::IsTerminator>()) {
297cb177712SRiver Riddle     if (BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(owner))
298cb177712SRiver Riddle       if (auto arg = branchInterface.getSuccessorBlockArgument(operandIndex))
299fafb708bSRiver Riddle         return !liveMap.wasProvenLive(*arg);
300fafb708bSRiver Riddle     return false;
301fafb708bSRiver Riddle   }
302fafb708bSRiver Riddle   return false;
303fafb708bSRiver Riddle }
304fafb708bSRiver Riddle 
305e62a6956SRiver Riddle static void processValue(Value value, LiveMap &liveMap) {
3062bdf33ccSRiver Riddle   bool provedLive = llvm::any_of(value.getUses(), [&](OpOperand &use) {
307fafb708bSRiver Riddle     if (isUseSpeciallyKnownDead(use, liveMap))
308fafb708bSRiver Riddle       return false;
309fafb708bSRiver Riddle     return liveMap.wasProvenLive(use.getOwner());
310fafb708bSRiver Riddle   });
311fafb708bSRiver Riddle   if (provedLive)
312fafb708bSRiver Riddle     liveMap.setProvedLive(value);
313fafb708bSRiver Riddle }
314fafb708bSRiver Riddle 
315fafb708bSRiver Riddle static void propagateLiveness(Region &region, LiveMap &liveMap);
316cb177712SRiver Riddle 
317cb177712SRiver Riddle static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
318cb177712SRiver Riddle   // Terminators are always live.
319cb177712SRiver Riddle   liveMap.setProvedLive(op);
320cb177712SRiver Riddle 
321cb177712SRiver Riddle   // Check to see if we can reason about the successor operands and mutate them.
322cb177712SRiver Riddle   BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(op);
3230752d98cSRiver Riddle   if (!branchInterface) {
324cb177712SRiver Riddle     for (Block *successor : op->getSuccessors())
325cb177712SRiver Riddle       for (BlockArgument arg : successor->getArguments())
326cb177712SRiver Riddle         liveMap.setProvedLive(arg);
327cb177712SRiver Riddle     return;
328cb177712SRiver Riddle   }
329cb177712SRiver Riddle 
3300c789db5SMarkus Böck   // If we can't reason about the operand to a successor, conservatively mark
3310c789db5SMarkus Böck   // it as live.
332cb177712SRiver Riddle   for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
3330c789db5SMarkus Böck     SuccessorOperands successorOperands =
3340c789db5SMarkus Böck         branchInterface.getSuccessorOperands(i);
3350c789db5SMarkus Böck     for (unsigned opI = 0, opE = successorOperands.getProducedOperandCount();
3360c789db5SMarkus Böck          opI != opE; ++opI)
3370c789db5SMarkus Böck       liveMap.setProvedLive(op->getSuccessor(i)->getArgument(opI));
338cb177712SRiver Riddle   }
339cb177712SRiver Riddle }
340cb177712SRiver Riddle 
341fafb708bSRiver Riddle static void propagateLiveness(Operation *op, LiveMap &liveMap) {
342fafb708bSRiver Riddle   // Recurse on any regions the op has.
343fafb708bSRiver Riddle   for (Region &region : op->getRegions())
344fafb708bSRiver Riddle     propagateLiveness(region, liveMap);
345fafb708bSRiver Riddle 
346cb177712SRiver Riddle   // Process terminator operations.
347fe7c0d90SRiver Riddle   if (op->hasTrait<OpTrait::IsTerminator>())
348cb177712SRiver Riddle     return propagateTerminatorLiveness(op, liveMap);
349cb177712SRiver Riddle 
3504e02eb80SRiver Riddle   // Don't reprocess live operations.
3514e02eb80SRiver Riddle   if (liveMap.wasProvenLive(op))
352fafb708bSRiver Riddle     return;
3534e02eb80SRiver Riddle 
3544e02eb80SRiver Riddle   // Process the op itself.
3554e02eb80SRiver Riddle   if (!wouldOpBeTriviallyDead(op))
3564e02eb80SRiver Riddle     return liveMap.setProvedLive(op);
3574e02eb80SRiver Riddle 
3584e02eb80SRiver Riddle   // If the op isn't intrinsically alive, check it's results.
359e62a6956SRiver Riddle   for (Value value : op->getResults())
360fafb708bSRiver Riddle     processValue(value, liveMap);
361fafb708bSRiver Riddle }
362fafb708bSRiver Riddle 
363fafb708bSRiver Riddle static void propagateLiveness(Region &region, LiveMap &liveMap) {
364fafb708bSRiver Riddle   if (region.empty())
365fafb708bSRiver Riddle     return;
366fafb708bSRiver Riddle 
367fafb708bSRiver Riddle   for (Block *block : llvm::post_order(&region.front())) {
368fafb708bSRiver Riddle     // We process block arguments after the ops in the block, to promote
369fafb708bSRiver Riddle     // faster convergence to a fixed point (we try to visit uses before defs).
370fafb708bSRiver Riddle     for (Operation &op : llvm::reverse(block->getOperations()))
371fafb708bSRiver Riddle       propagateLiveness(&op, liveMap);
3724e02eb80SRiver Riddle 
3734e02eb80SRiver Riddle     // We currently do not remove entry block arguments, so there is no need to
3744e02eb80SRiver Riddle     // track their liveness.
3754e02eb80SRiver Riddle     // TODO: We could track these and enable removing dead operands/arguments
3764e02eb80SRiver Riddle     // from region control flow operations.
3774e02eb80SRiver Riddle     if (block->isEntryBlock())
3784e02eb80SRiver Riddle       continue;
3794e02eb80SRiver Riddle 
3804e02eb80SRiver Riddle     for (Value value : block->getArguments()) {
3814e02eb80SRiver Riddle       if (!liveMap.wasProvenLive(value))
382fafb708bSRiver Riddle         processValue(value, liveMap);
383fafb708bSRiver Riddle     }
384fafb708bSRiver Riddle   }
3854e02eb80SRiver Riddle }
386fafb708bSRiver Riddle 
387fafb708bSRiver Riddle static void eraseTerminatorSuccessorOperands(Operation *terminator,
388fafb708bSRiver Riddle                                              LiveMap &liveMap) {
389cb177712SRiver Riddle   BranchOpInterface branchOp = dyn_cast<BranchOpInterface>(terminator);
390cb177712SRiver Riddle   if (!branchOp)
391cb177712SRiver Riddle     return;
392cb177712SRiver Riddle 
393fafb708bSRiver Riddle   for (unsigned succI = 0, succE = terminator->getNumSuccessors();
394fafb708bSRiver Riddle        succI < succE; succI++) {
395fafb708bSRiver Riddle     // Iterating successors in reverse is not strictly needed, since we
396fafb708bSRiver Riddle     // aren't erasing any successors. But it is slightly more efficient
397fafb708bSRiver Riddle     // since it will promote later operands of the terminator being erased
398fafb708bSRiver Riddle     // first, reducing the quadratic-ness.
399fafb708bSRiver Riddle     unsigned succ = succE - succI - 1;
4000c789db5SMarkus Böck     SuccessorOperands succOperands = branchOp.getSuccessorOperands(succ);
401cb177712SRiver Riddle     Block *successor = terminator->getSuccessor(succ);
402cb177712SRiver Riddle 
4030c789db5SMarkus Böck     for (unsigned argI = 0, argE = succOperands.size(); argI < argE; ++argI) {
404fafb708bSRiver Riddle       // Iterating args in reverse is needed for correctness, to avoid
405fafb708bSRiver Riddle       // shifting later args when earlier args are erased.
406fafb708bSRiver Riddle       unsigned arg = argE - argI - 1;
407cb177712SRiver Riddle       if (!liveMap.wasProvenLive(successor->getArgument(arg)))
4080c789db5SMarkus Böck         succOperands.erase(arg);
409fafb708bSRiver Riddle     }
410fafb708bSRiver Riddle   }
411fafb708bSRiver Riddle }
412fafb708bSRiver Riddle 
413d75a611aSRiver Riddle static LogicalResult deleteDeadness(RewriterBase &rewriter,
414d75a611aSRiver Riddle                                     MutableArrayRef<Region> regions,
415fafb708bSRiver Riddle                                     LiveMap &liveMap) {
416fafb708bSRiver Riddle   bool erasedAnything = false;
417fafb708bSRiver Riddle   for (Region &region : regions) {
418fafb708bSRiver Riddle     if (region.empty())
419fafb708bSRiver Riddle       continue;
420973ddb7dSMehdi Amini     bool hasSingleBlock = llvm::hasSingleElement(region);
421fafb708bSRiver Riddle 
422f178c13fSAndrew Young     // Delete every operation that is not live. Graph regions may have cycles
423f178c13fSAndrew Young     // in the use-def graph, so we must explicitly dropAllUses() from each
424f178c13fSAndrew Young     // operation as we erase it. Visiting the operations in post-order
425f178c13fSAndrew Young     // guarantees that in SSA CFG regions value uses are removed before defs,
426f178c13fSAndrew Young     // which makes dropAllUses() a no-op.
427fafb708bSRiver Riddle     for (Block *block : llvm::post_order(&region.front())) {
428973ddb7dSMehdi Amini       if (!hasSingleBlock)
429fafb708bSRiver Riddle         eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap);
430fafb708bSRiver Riddle       for (Operation &childOp :
431fafb708bSRiver Riddle            llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) {
432fafb708bSRiver Riddle         if (!liveMap.wasProvenLive(&childOp)) {
433fafb708bSRiver Riddle           erasedAnything = true;
434f178c13fSAndrew Young           childOp.dropAllUses();
435d75a611aSRiver Riddle           rewriter.eraseOp(&childOp);
4364e02eb80SRiver Riddle         } else {
437d75a611aSRiver Riddle           erasedAnything |= succeeded(
438d75a611aSRiver Riddle               deleteDeadness(rewriter, childOp.getRegions(), liveMap));
439fafb708bSRiver Riddle         }
440fafb708bSRiver Riddle       }
441fafb708bSRiver Riddle     }
442fafb708bSRiver Riddle     // Delete block arguments.
443fafb708bSRiver Riddle     // The entry block has an unknown contract with their enclosing block, so
444fafb708bSRiver Riddle     // skip it.
4454ea92a05SRiver Riddle     for (Block &block : llvm::drop_begin(region.getBlocks(), 1)) {
4464e02eb80SRiver Riddle       block.eraseArguments(
4474e02eb80SRiver Riddle           [&](BlockArgument arg) { return !liveMap.wasProvenLive(arg); });
448fafb708bSRiver Riddle     }
449fafb708bSRiver Riddle   }
450fafb708bSRiver Riddle   return success(erasedAnything);
451fafb708bSRiver Riddle }
452fafb708bSRiver Riddle 
453fafb708bSRiver Riddle // This function performs a simple dead code elimination algorithm over the
454fafb708bSRiver Riddle // given regions.
455fafb708bSRiver Riddle //
456fafb708bSRiver Riddle // The overall goal is to prove that Values are dead, which allows deleting ops
457fafb708bSRiver Riddle // and block arguments.
458fafb708bSRiver Riddle //
459fafb708bSRiver Riddle // This uses an optimistic algorithm that assumes everything is dead until
460fafb708bSRiver Riddle // proved otherwise, allowing it to delete recursively dead cycles.
461fafb708bSRiver Riddle //
462fafb708bSRiver Riddle // This is a simple fixed-point dataflow analysis algorithm on a lattice
463fafb708bSRiver Riddle // {Dead,Alive}. Because liveness flows backward, we generally try to
464fafb708bSRiver Riddle // iterate everything backward to speed up convergence to the fixed-point. This
465fafb708bSRiver Riddle // allows for being able to delete recursively dead cycles of the use-def graph,
466fafb708bSRiver Riddle // including block arguments.
467fafb708bSRiver Riddle //
468fafb708bSRiver Riddle // This function returns success if any operations or arguments were deleted,
469fafb708bSRiver Riddle // failure otherwise.
47078d69182SValentin Clement LogicalResult mlir::runRegionDCE(RewriterBase &rewriter,
471d75a611aSRiver Riddle                                  MutableArrayRef<Region> regions) {
472fafb708bSRiver Riddle   LiveMap liveMap;
473fafb708bSRiver Riddle   do {
474fafb708bSRiver Riddle     liveMap.resetChanged();
475fafb708bSRiver Riddle 
476fafb708bSRiver Riddle     for (Region &region : regions)
477fafb708bSRiver Riddle       propagateLiveness(region, liveMap);
478fafb708bSRiver Riddle   } while (liveMap.hasChanged());
479fafb708bSRiver Riddle 
480d75a611aSRiver Riddle   return deleteDeadness(rewriter, regions, liveMap);
481fafb708bSRiver Riddle }
482fafb708bSRiver Riddle 
483fafb708bSRiver Riddle //===----------------------------------------------------------------------===//
484469c02d0SRiver Riddle // Block Merging
485469c02d0SRiver Riddle //===----------------------------------------------------------------------===//
486469c02d0SRiver Riddle 
487469c02d0SRiver Riddle //===----------------------------------------------------------------------===//
488469c02d0SRiver Riddle // BlockEquivalenceData
489469c02d0SRiver Riddle 
490469c02d0SRiver Riddle namespace {
491469c02d0SRiver Riddle /// This class contains the information for comparing the equivalencies of two
492469c02d0SRiver Riddle /// blocks. Blocks are considered equivalent if they contain the same operations
493469c02d0SRiver Riddle /// in the same order. The only allowed divergence is for operands that come
494469c02d0SRiver Riddle /// from sources outside of the parent block, i.e. the uses of values produced
495469c02d0SRiver Riddle /// within the block must be equivalent.
496469c02d0SRiver Riddle ///   e.g.,
497469c02d0SRiver Riddle /// Equivalent:
498469c02d0SRiver Riddle ///  ^bb1(%arg0: i32)
499469c02d0SRiver Riddle ///    return %arg0, %foo : i32, i32
500469c02d0SRiver Riddle ///  ^bb2(%arg1: i32)
501469c02d0SRiver Riddle ///    return %arg1, %bar : i32, i32
502469c02d0SRiver Riddle /// Not Equivalent:
503469c02d0SRiver Riddle ///  ^bb1(%arg0: i32)
504469c02d0SRiver Riddle ///    return %foo, %arg0 : i32, i32
505469c02d0SRiver Riddle ///  ^bb2(%arg1: i32)
506469c02d0SRiver Riddle ///    return %arg1, %bar : i32, i32
507469c02d0SRiver Riddle struct BlockEquivalenceData {
508469c02d0SRiver Riddle   BlockEquivalenceData(Block *block);
509469c02d0SRiver Riddle 
510469c02d0SRiver Riddle   /// Return the order index for the given value that is within the block of
511469c02d0SRiver Riddle   /// this data.
512469c02d0SRiver Riddle   unsigned getOrderOf(Value value) const;
513469c02d0SRiver Riddle 
514469c02d0SRiver Riddle   /// The block this data refers to.
515469c02d0SRiver Riddle   Block *block;
516469c02d0SRiver Riddle   /// A hash value for this block.
517469c02d0SRiver Riddle   llvm::hash_code hash;
518469c02d0SRiver Riddle   /// A map of result producing operations to their relative orders within this
519469c02d0SRiver Riddle   /// block. The order of an operation is the number of defined values that are
520469c02d0SRiver Riddle   /// produced within the block before this operation.
521469c02d0SRiver Riddle   DenseMap<Operation *, unsigned> opOrderIndex;
522469c02d0SRiver Riddle };
523be0a7e9fSMehdi Amini } // namespace
524469c02d0SRiver Riddle 
525469c02d0SRiver Riddle BlockEquivalenceData::BlockEquivalenceData(Block *block)
526469c02d0SRiver Riddle     : block(block), hash(0) {
527469c02d0SRiver Riddle   unsigned orderIt = block->getNumArguments();
528469c02d0SRiver Riddle   for (Operation &op : *block) {
529469c02d0SRiver Riddle     if (unsigned numResults = op.getNumResults()) {
530469c02d0SRiver Riddle       opOrderIndex.try_emplace(&op, orderIt);
531469c02d0SRiver Riddle       orderIt += numResults;
532469c02d0SRiver Riddle     }
533469c02d0SRiver Riddle     auto opHash = OperationEquivalence::computeHash(
5340be5d1a9SMehdi Amini         &op, OperationEquivalence::ignoreHashValue,
5350be5d1a9SMehdi Amini         OperationEquivalence::ignoreHashValue,
5360be5d1a9SMehdi Amini         OperationEquivalence::IgnoreLocations);
537469c02d0SRiver Riddle     hash = llvm::hash_combine(hash, opHash);
538469c02d0SRiver Riddle   }
539469c02d0SRiver Riddle }
540469c02d0SRiver Riddle 
541469c02d0SRiver Riddle unsigned BlockEquivalenceData::getOrderOf(Value value) const {
542469c02d0SRiver Riddle   assert(value.getParentBlock() == block && "expected value of this block");
543469c02d0SRiver Riddle 
544469c02d0SRiver Riddle   // Arguments use the argument number as the order index.
5455550c821STres Popp   if (BlockArgument arg = dyn_cast<BlockArgument>(value))
546469c02d0SRiver Riddle     return arg.getArgNumber();
547469c02d0SRiver Riddle 
548469c02d0SRiver Riddle   // Otherwise, the result order is offset from the parent op's order.
5495550c821STres Popp   OpResult result = cast<OpResult>(value);
550469c02d0SRiver Riddle   auto opOrderIt = opOrderIndex.find(result.getDefiningOp());
551469c02d0SRiver Riddle   assert(opOrderIt != opOrderIndex.end() && "expected op to have an order");
552469c02d0SRiver Riddle   return opOrderIt->second + result.getResultNumber();
553469c02d0SRiver Riddle }
554469c02d0SRiver Riddle 
555469c02d0SRiver Riddle //===----------------------------------------------------------------------===//
556469c02d0SRiver Riddle // BlockMergeCluster
557469c02d0SRiver Riddle 
558469c02d0SRiver Riddle namespace {
559469c02d0SRiver Riddle /// This class represents a cluster of blocks to be merged together.
560469c02d0SRiver Riddle class BlockMergeCluster {
561469c02d0SRiver Riddle public:
562469c02d0SRiver Riddle   BlockMergeCluster(BlockEquivalenceData &&leaderData)
563469c02d0SRiver Riddle       : leaderData(std::move(leaderData)) {}
564469c02d0SRiver Riddle 
565469c02d0SRiver Riddle   /// Attempt to add the given block to this cluster. Returns success if the
566469c02d0SRiver Riddle   /// block was merged, failure otherwise.
567469c02d0SRiver Riddle   LogicalResult addToCluster(BlockEquivalenceData &blockData);
568469c02d0SRiver Riddle 
569469c02d0SRiver Riddle   /// Try to merge all of the blocks within this cluster into the leader block.
570d75a611aSRiver Riddle   LogicalResult merge(RewriterBase &rewriter);
571469c02d0SRiver Riddle 
572469c02d0SRiver Riddle private:
573469c02d0SRiver Riddle   /// The equivalence data for the leader of the cluster.
574469c02d0SRiver Riddle   BlockEquivalenceData leaderData;
575469c02d0SRiver Riddle 
576469c02d0SRiver Riddle   /// The set of blocks that can be merged into the leader.
577469c02d0SRiver Riddle   llvm::SmallSetVector<Block *, 1> blocksToMerge;
578469c02d0SRiver Riddle 
579469c02d0SRiver Riddle   /// A set of operand+index pairs that correspond to operands that need to be
580469c02d0SRiver Riddle   /// replaced by arguments when the cluster gets merged.
581469c02d0SRiver Riddle   std::set<std::pair<int, int>> operandsToMerge;
582469c02d0SRiver Riddle };
583be0a7e9fSMehdi Amini } // namespace
584469c02d0SRiver Riddle 
585469c02d0SRiver Riddle LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
586469c02d0SRiver Riddle   if (leaderData.hash != blockData.hash)
587469c02d0SRiver Riddle     return failure();
588469c02d0SRiver Riddle   Block *leaderBlock = leaderData.block, *mergeBlock = blockData.block;
589469c02d0SRiver Riddle   if (leaderBlock->getArgumentTypes() != mergeBlock->getArgumentTypes())
590469c02d0SRiver Riddle     return failure();
591469c02d0SRiver Riddle 
592469c02d0SRiver Riddle   // A set of operands that mismatch between the leader and the new block.
593469c02d0SRiver Riddle   SmallVector<std::pair<int, int>, 8> mismatchedOperands;
594469c02d0SRiver Riddle   auto lhsIt = leaderBlock->begin(), lhsE = leaderBlock->end();
595469c02d0SRiver Riddle   auto rhsIt = blockData.block->begin(), rhsE = blockData.block->end();
596469c02d0SRiver Riddle   for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) {
597469c02d0SRiver Riddle     // Check that the operations are equivalent.
598469c02d0SRiver Riddle     if (!OperationEquivalence::isEquivalentTo(
5990be5d1a9SMehdi Amini             &*lhsIt, &*rhsIt, OperationEquivalence::ignoreValueEquivalence,
600c864288dSMatthias Springer             /*markEquivalent=*/nullptr,
6010be5d1a9SMehdi Amini             OperationEquivalence::Flags::IgnoreLocations))
602469c02d0SRiver Riddle       return failure();
603469c02d0SRiver Riddle 
604469c02d0SRiver Riddle     // Compare the operands of the two operations. If the operand is within
605469c02d0SRiver Riddle     // the block, it must refer to the same operation.
606469c02d0SRiver Riddle     auto lhsOperands = lhsIt->getOperands(), rhsOperands = rhsIt->getOperands();
607469c02d0SRiver Riddle     for (int operand : llvm::seq<int>(0, lhsIt->getNumOperands())) {
608469c02d0SRiver Riddle       Value lhsOperand = lhsOperands[operand];
609469c02d0SRiver Riddle       Value rhsOperand = rhsOperands[operand];
610469c02d0SRiver Riddle       if (lhsOperand == rhsOperand)
611469c02d0SRiver Riddle         continue;
612474f7639SRiver Riddle       // Check that the types of the operands match.
613474f7639SRiver Riddle       if (lhsOperand.getType() != rhsOperand.getType())
614474f7639SRiver Riddle         return failure();
615469c02d0SRiver Riddle 
616469c02d0SRiver Riddle       // Check that these uses are both external, or both internal.
617469c02d0SRiver Riddle       bool lhsIsInBlock = lhsOperand.getParentBlock() == leaderBlock;
618469c02d0SRiver Riddle       bool rhsIsInBlock = rhsOperand.getParentBlock() == mergeBlock;
619469c02d0SRiver Riddle       if (lhsIsInBlock != rhsIsInBlock)
620469c02d0SRiver Riddle         return failure();
621469c02d0SRiver Riddle       // Let the operands differ if they are defined in a different block. These
622469c02d0SRiver Riddle       // will become new arguments if the blocks get merged.
623469c02d0SRiver Riddle       if (!lhsIsInBlock) {
624c14ba3c4SMarkus Böck 
625c14ba3c4SMarkus Böck         // Check whether the operands aren't the result of an immediate
626c14ba3c4SMarkus Böck         // predecessors terminator. In that case we are not able to use it as a
627c14ba3c4SMarkus Böck         // successor operand when branching to the merged block as it does not
628c14ba3c4SMarkus Böck         // dominate its producing operation.
629c14ba3c4SMarkus Böck         auto isValidSuccessorArg = [](Block *block, Value operand) {
630c14ba3c4SMarkus Böck           if (operand.getDefiningOp() !=
631c14ba3c4SMarkus Böck               operand.getParentBlock()->getTerminator())
632c14ba3c4SMarkus Böck             return true;
633c14ba3c4SMarkus Böck           return !llvm::is_contained(block->getPredecessors(),
634c14ba3c4SMarkus Böck                                      operand.getParentBlock());
635c14ba3c4SMarkus Böck         };
636c14ba3c4SMarkus Böck 
637c14ba3c4SMarkus Böck         if (!isValidSuccessorArg(leaderBlock, lhsOperand) ||
638c14ba3c4SMarkus Böck             !isValidSuccessorArg(mergeBlock, rhsOperand))
639a7865228SMarkus Böck           return failure();
640c14ba3c4SMarkus Böck 
641469c02d0SRiver Riddle         mismatchedOperands.emplace_back(opI, operand);
642469c02d0SRiver Riddle         continue;
643469c02d0SRiver Riddle       }
644469c02d0SRiver Riddle 
645469c02d0SRiver Riddle       // Otherwise, these operands must have the same logical order within the
646469c02d0SRiver Riddle       // parent block.
647469c02d0SRiver Riddle       if (leaderData.getOrderOf(lhsOperand) != blockData.getOrderOf(rhsOperand))
648469c02d0SRiver Riddle         return failure();
649469c02d0SRiver Riddle     }
650469c02d0SRiver Riddle 
651f5c5fd1cSWilliam S. Moses     // If the lhs or rhs has external uses, the blocks cannot be merged as the
652f5c5fd1cSWilliam S. Moses     // merged version of this operation will not be either the lhs or rhs
653f88fab50SKazuaki Ishizaki     // alone (thus semantically incorrect), but some mix dependending on which
654f5c5fd1cSWilliam S. Moses     // block preceeded this.
655f5c5fd1cSWilliam S. Moses     // TODO allow merging of operations when one block does not dominate the
656f5c5fd1cSWilliam S. Moses     // other
657f5c5fd1cSWilliam S. Moses     if (rhsIt->isUsedOutsideOfBlock(mergeBlock) ||
658f5c5fd1cSWilliam S. Moses         lhsIt->isUsedOutsideOfBlock(leaderBlock)) {
659f5c5fd1cSWilliam S. Moses       return failure();
660f5c5fd1cSWilliam S. Moses     }
661469c02d0SRiver Riddle   }
662469c02d0SRiver Riddle   // Make sure that the block sizes are equivalent.
663469c02d0SRiver Riddle   if (lhsIt != lhsE || rhsIt != rhsE)
664469c02d0SRiver Riddle     return failure();
665469c02d0SRiver Riddle 
666469c02d0SRiver Riddle   // If we get here, the blocks are equivalent and can be merged.
667469c02d0SRiver Riddle   operandsToMerge.insert(mismatchedOperands.begin(), mismatchedOperands.end());
668469c02d0SRiver Riddle   blocksToMerge.insert(blockData.block);
669469c02d0SRiver Riddle   return success();
670469c02d0SRiver Riddle }
671469c02d0SRiver Riddle 
672469c02d0SRiver Riddle /// Returns true if the predecessor terminators of the given block can not have
673469c02d0SRiver Riddle /// their operands updated.
674469c02d0SRiver Riddle static bool ableToUpdatePredOperands(Block *block) {
675469c02d0SRiver Riddle   for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
6760c789db5SMarkus Böck     if (!isa<BranchOpInterface>((*it)->getTerminator()))
677469c02d0SRiver Riddle       return false;
678469c02d0SRiver Riddle   }
679469c02d0SRiver Riddle   return true;
680469c02d0SRiver Riddle }
681469c02d0SRiver Riddle 
682441b672bSGiuseppe Rossini /// Prunes the redundant list of new arguments. E.g., if we are passing an
683441b672bSGiuseppe Rossini /// argument list like [x, y, z, x] this would return [x, y, z] and it would
684441b672bSGiuseppe Rossini /// update the `block` (to whom the argument are passed to) accordingly. The new
685441b672bSGiuseppe Rossini /// arguments are passed as arguments at the back of the block, hence we need to
686441b672bSGiuseppe Rossini /// know how many `numOldArguments` were before, in order to correctly replace
687441b672bSGiuseppe Rossini /// the new arguments in the block
688441b672bSGiuseppe Rossini static SmallVector<SmallVector<Value, 8>, 2> pruneRedundantArguments(
689441b672bSGiuseppe Rossini     const SmallVector<SmallVector<Value, 8>, 2> &newArguments,
690441b672bSGiuseppe Rossini     RewriterBase &rewriter, unsigned numOldArguments, Block *block) {
691441b672bSGiuseppe Rossini 
692441b672bSGiuseppe Rossini   SmallVector<SmallVector<Value, 8>, 2> newArgumentsPruned(
693441b672bSGiuseppe Rossini       newArguments.size(), SmallVector<Value, 8>());
694441b672bSGiuseppe Rossini 
695441b672bSGiuseppe Rossini   if (newArguments.empty())
696441b672bSGiuseppe Rossini     return newArguments;
697441b672bSGiuseppe Rossini 
698441b672bSGiuseppe Rossini   // `newArguments` is a 2D array of size `numLists` x `numArgs`
699441b672bSGiuseppe Rossini   unsigned numLists = newArguments.size();
700441b672bSGiuseppe Rossini   unsigned numArgs = newArguments[0].size();
701441b672bSGiuseppe Rossini 
702441b672bSGiuseppe Rossini   // Map that for each arg index contains the index that we can use in place of
703441b672bSGiuseppe Rossini   // the original index. E.g., if we have newArgs = [x, y, z, x], we will have
704441b672bSGiuseppe Rossini   // idxToReplacement[3] = 0
705441b672bSGiuseppe Rossini   llvm::DenseMap<unsigned, unsigned> idxToReplacement;
706441b672bSGiuseppe Rossini 
707441b672bSGiuseppe Rossini   // This is a useful data structure to track the first appearance of a Value
708441b672bSGiuseppe Rossini   // on a given list of arguments
709441b672bSGiuseppe Rossini   DenseMap<Value, unsigned> firstValueToIdx;
710441b672bSGiuseppe Rossini   for (unsigned j = 0; j < numArgs; ++j) {
711441b672bSGiuseppe Rossini     Value newArg = newArguments[0][j];
712*1e5f32e8SKazu Hirata     firstValueToIdx.try_emplace(newArg, j);
713441b672bSGiuseppe Rossini   }
714441b672bSGiuseppe Rossini 
715441b672bSGiuseppe Rossini   // Go through the first list of arguments (list 0).
716441b672bSGiuseppe Rossini   for (unsigned j = 0; j < numArgs; ++j) {
717441b672bSGiuseppe Rossini     // Look back to see if there are possible redundancies in list 0. Please
718441b672bSGiuseppe Rossini     // note that we are using a map to annotate when an argument was seen first
719441b672bSGiuseppe Rossini     // to avoid a O(N^2) algorithm. This has the drawback that if we have two
720441b672bSGiuseppe Rossini     // lists like:
721441b672bSGiuseppe Rossini     // list0: [%a, %a, %a]
722441b672bSGiuseppe Rossini     // list1: [%c, %b, %b]
723441b672bSGiuseppe Rossini     // We cannot simplify it, because firstValueToIdx[%a] = 0, but we cannot
724441b672bSGiuseppe Rossini     // point list1[1](==%b) or list1[2](==%b) to list1[0](==%c).  However, since
725441b672bSGiuseppe Rossini     // the number of arguments can be potentially unbounded we cannot afford a
726441b672bSGiuseppe Rossini     // O(N^2) algorithm (to search to all the possible pairs) and we need to
727441b672bSGiuseppe Rossini     // accept the trade-off.
728441b672bSGiuseppe Rossini     unsigned k = firstValueToIdx[newArguments[0][j]];
729441b672bSGiuseppe Rossini     if (k == j)
730441b672bSGiuseppe Rossini       continue;
731441b672bSGiuseppe Rossini 
732441b672bSGiuseppe Rossini     bool shouldReplaceJ = true;
733441b672bSGiuseppe Rossini     unsigned replacement = k;
734441b672bSGiuseppe Rossini     // If a possible redundancy is found, then scan the other lists: we
735441b672bSGiuseppe Rossini     // can prune the arguments if and only if they are redundant in every
736441b672bSGiuseppe Rossini     // list.
737441b672bSGiuseppe Rossini     for (unsigned i = 1; i < numLists; ++i)
738441b672bSGiuseppe Rossini       shouldReplaceJ =
739441b672bSGiuseppe Rossini           shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
740441b672bSGiuseppe Rossini     // Save the replacement.
741441b672bSGiuseppe Rossini     if (shouldReplaceJ)
742441b672bSGiuseppe Rossini       idxToReplacement[j] = replacement;
743441b672bSGiuseppe Rossini   }
744441b672bSGiuseppe Rossini 
745441b672bSGiuseppe Rossini   // Populate the pruned argument list.
746441b672bSGiuseppe Rossini   for (unsigned i = 0; i < numLists; ++i)
747441b672bSGiuseppe Rossini     for (unsigned j = 0; j < numArgs; ++j)
748441b672bSGiuseppe Rossini       if (!idxToReplacement.contains(j))
749441b672bSGiuseppe Rossini         newArgumentsPruned[i].push_back(newArguments[i][j]);
750441b672bSGiuseppe Rossini 
751441b672bSGiuseppe Rossini   // Replace the block's redundant arguments.
752441b672bSGiuseppe Rossini   SmallVector<unsigned> toErase;
753441b672bSGiuseppe Rossini   for (auto [idx, arg] : llvm::enumerate(block->getArguments())) {
754441b672bSGiuseppe Rossini     if (idxToReplacement.contains(idx)) {
755441b672bSGiuseppe Rossini       Value oldArg = block->getArgument(numOldArguments + idx);
756441b672bSGiuseppe Rossini       Value newArg =
757441b672bSGiuseppe Rossini           block->getArgument(numOldArguments + idxToReplacement[idx]);
758441b672bSGiuseppe Rossini       rewriter.replaceAllUsesWith(oldArg, newArg);
759441b672bSGiuseppe Rossini       toErase.push_back(numOldArguments + idx);
760441b672bSGiuseppe Rossini     }
761441b672bSGiuseppe Rossini   }
762441b672bSGiuseppe Rossini 
763441b672bSGiuseppe Rossini   // Erase the block's redundant arguments.
764441b672bSGiuseppe Rossini   for (unsigned idxToErase : llvm::reverse(toErase))
765441b672bSGiuseppe Rossini     block->eraseArgument(idxToErase);
766441b672bSGiuseppe Rossini   return newArgumentsPruned;
767441b672bSGiuseppe Rossini }
768441b672bSGiuseppe Rossini 
769d75a611aSRiver Riddle LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
770469c02d0SRiver Riddle   // Don't consider clusters that don't have blocks to merge.
771469c02d0SRiver Riddle   if (blocksToMerge.empty())
772469c02d0SRiver Riddle     return failure();
773469c02d0SRiver Riddle 
774469c02d0SRiver Riddle   Block *leaderBlock = leaderData.block;
775469c02d0SRiver Riddle   if (!operandsToMerge.empty()) {
776469c02d0SRiver Riddle     // If the cluster has operands to merge, verify that the predecessor
777469c02d0SRiver Riddle     // terminators of each of the blocks can have their successor operands
778469c02d0SRiver Riddle     // updated.
779469c02d0SRiver Riddle     // TODO: We could try and sub-partition this cluster if only some blocks
780469c02d0SRiver Riddle     // cause the mismatch.
781469c02d0SRiver Riddle     if (!ableToUpdatePredOperands(leaderBlock) ||
782469c02d0SRiver Riddle         !llvm::all_of(blocksToMerge, ableToUpdatePredOperands))
783469c02d0SRiver Riddle       return failure();
784469c02d0SRiver Riddle 
785469c02d0SRiver Riddle     // Collect the iterators for each of the blocks to merge. We will walk all
786469c02d0SRiver Riddle     // of the iterators at once to avoid operand index invalidation.
787469c02d0SRiver Riddle     SmallVector<Block::iterator, 2> blockIterators;
788469c02d0SRiver Riddle     blockIterators.reserve(blocksToMerge.size() + 1);
789469c02d0SRiver Riddle     blockIterators.push_back(leaderBlock->begin());
790469c02d0SRiver Riddle     for (Block *mergeBlock : blocksToMerge)
791469c02d0SRiver Riddle       blockIterators.push_back(mergeBlock->begin());
792469c02d0SRiver Riddle 
793469c02d0SRiver Riddle     // Update each of the predecessor terminators with the new arguments.
79428a11cc4SMehdi Amini     SmallVector<SmallVector<Value, 8>, 2> newArguments(
79528a11cc4SMehdi Amini         1 + blocksToMerge.size(),
79628a11cc4SMehdi Amini         SmallVector<Value, 8>(operandsToMerge.size()));
797469c02d0SRiver Riddle     unsigned curOpIndex = 0;
798441b672bSGiuseppe Rossini     unsigned numOldArguments = leaderBlock->getNumArguments();
799e4853be2SMehdi Amini     for (const auto &it : llvm::enumerate(operandsToMerge)) {
800469c02d0SRiver Riddle       unsigned nextOpOffset = it.value().first - curOpIndex;
801469c02d0SRiver Riddle       curOpIndex = it.value().first;
802469c02d0SRiver Riddle 
803469c02d0SRiver Riddle       // Process the operand for each of the block iterators.
804469c02d0SRiver Riddle       for (unsigned i = 0, e = blockIterators.size(); i != e; ++i) {
805469c02d0SRiver Riddle         Block::iterator &blockIter = blockIterators[i];
806469c02d0SRiver Riddle         std::advance(blockIter, nextOpOffset);
807469c02d0SRiver Riddle         auto &operand = blockIter->getOpOperand(it.value().second);
80828a11cc4SMehdi Amini         newArguments[i][it.index()] = operand.get();
80928a11cc4SMehdi Amini 
810469c02d0SRiver Riddle         // Update the operand and insert an argument if this is the leader.
811e084679fSRiver Riddle         if (i == 0) {
81228a11cc4SMehdi Amini           Value operandVal = operand.get();
813e084679fSRiver Riddle           operand.set(leaderBlock->addArgument(operandVal.getType(),
814e084679fSRiver Riddle                                                operandVal.getLoc()));
815e084679fSRiver Riddle         }
816469c02d0SRiver Riddle       }
817469c02d0SRiver Riddle     }
818441b672bSGiuseppe Rossini 
819441b672bSGiuseppe Rossini     // Prune redundant arguments and update the leader block argument list
820441b672bSGiuseppe Rossini     newArguments = pruneRedundantArguments(newArguments, rewriter,
821441b672bSGiuseppe Rossini                                            numOldArguments, leaderBlock);
822441b672bSGiuseppe Rossini 
823469c02d0SRiver Riddle     // Update the predecessors for each of the blocks.
824469c02d0SRiver Riddle     auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
825469c02d0SRiver Riddle       for (auto predIt = block->pred_begin(), predE = block->pred_end();
826469c02d0SRiver Riddle            predIt != predE; ++predIt) {
827469c02d0SRiver Riddle         auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
828469c02d0SRiver Riddle         unsigned succIndex = predIt.getSuccessorIndex();
8290c789db5SMarkus Böck         branch.getSuccessorOperands(succIndex).append(
830469c02d0SRiver Riddle             newArguments[clusterIndex]);
831469c02d0SRiver Riddle       }
832469c02d0SRiver Riddle     };
833469c02d0SRiver Riddle     updatePredecessors(leaderBlock, /*clusterIndex=*/0);
834469c02d0SRiver Riddle     for (unsigned i = 0, e = blocksToMerge.size(); i != e; ++i)
835469c02d0SRiver Riddle       updatePredecessors(blocksToMerge[i], /*clusterIndex=*/i + 1);
836469c02d0SRiver Riddle   }
837469c02d0SRiver Riddle 
838469c02d0SRiver Riddle   // Replace all uses of the merged blocks with the leader and erase them.
839469c02d0SRiver Riddle   for (Block *block : blocksToMerge) {
840469c02d0SRiver Riddle     block->replaceAllUsesWith(leaderBlock);
841d75a611aSRiver Riddle     rewriter.eraseBlock(block);
842469c02d0SRiver Riddle   }
843469c02d0SRiver Riddle   return success();
844469c02d0SRiver Riddle }
845469c02d0SRiver Riddle 
846469c02d0SRiver Riddle /// Identify identical blocks within the given region and merge them, inserting
847469c02d0SRiver Riddle /// new block arguments as necessary. Returns success if any blocks were merged,
848469c02d0SRiver Riddle /// failure otherwise.
849d75a611aSRiver Riddle static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
850d75a611aSRiver Riddle                                           Region &region) {
851469c02d0SRiver Riddle   if (region.empty() || llvm::hasSingleElement(region))
852469c02d0SRiver Riddle     return failure();
853469c02d0SRiver Riddle 
854469c02d0SRiver Riddle   // Identify sets of blocks, other than the entry block, that branch to the
855469c02d0SRiver Riddle   // same successors. We will use these groups to create clusters of equivalent
856469c02d0SRiver Riddle   // blocks.
857469c02d0SRiver Riddle   DenseMap<SuccessorRange, SmallVector<Block *, 1>> matchingSuccessors;
858469c02d0SRiver Riddle   for (Block &block : llvm::drop_begin(region, 1))
859469c02d0SRiver Riddle     matchingSuccessors[block.getSuccessors()].push_back(&block);
860469c02d0SRiver Riddle 
861469c02d0SRiver Riddle   bool mergedAnyBlocks = false;
862469c02d0SRiver Riddle   for (ArrayRef<Block *> blocks : llvm::make_second_range(matchingSuccessors)) {
863469c02d0SRiver Riddle     if (blocks.size() == 1)
864469c02d0SRiver Riddle       continue;
865469c02d0SRiver Riddle 
866469c02d0SRiver Riddle     SmallVector<BlockMergeCluster, 1> clusters;
867469c02d0SRiver Riddle     for (Block *block : blocks) {
868469c02d0SRiver Riddle       BlockEquivalenceData data(block);
869469c02d0SRiver Riddle 
870469c02d0SRiver Riddle       // Don't allow merging if this block has any regions.
871469c02d0SRiver Riddle       // TODO: Add support for regions if necessary.
872469c02d0SRiver Riddle       bool hasNonEmptyRegion = llvm::any_of(*block, [](Operation &op) {
873469c02d0SRiver Riddle         return llvm::any_of(op.getRegions(),
874469c02d0SRiver Riddle                             [](Region &region) { return !region.empty(); });
875469c02d0SRiver Riddle       });
876469c02d0SRiver Riddle       if (hasNonEmptyRegion)
877469c02d0SRiver Riddle         continue;
878469c02d0SRiver Riddle 
879c50fecaaSBen Howe       // Don't allow merging if this block's arguments are used outside of the
880c50fecaaSBen Howe       // original block.
881c50fecaaSBen Howe       bool argHasExternalUsers = llvm::any_of(
882c50fecaaSBen Howe           block->getArguments(), [block](mlir::BlockArgument &arg) {
883c50fecaaSBen Howe             return arg.isUsedOutsideOfBlock(block);
884c50fecaaSBen Howe           });
885c50fecaaSBen Howe       if (argHasExternalUsers)
886c50fecaaSBen Howe         continue;
887c50fecaaSBen Howe 
888469c02d0SRiver Riddle       // Try to add this block to an existing cluster.
889469c02d0SRiver Riddle       bool addedToCluster = false;
890469c02d0SRiver Riddle       for (auto &cluster : clusters)
891469c02d0SRiver Riddle         if ((addedToCluster = succeeded(cluster.addToCluster(data))))
892469c02d0SRiver Riddle           break;
893469c02d0SRiver Riddle       if (!addedToCluster)
894469c02d0SRiver Riddle         clusters.emplace_back(std::move(data));
895469c02d0SRiver Riddle     }
896469c02d0SRiver Riddle     for (auto &cluster : clusters)
897d75a611aSRiver Riddle       mergedAnyBlocks |= succeeded(cluster.merge(rewriter));
898469c02d0SRiver Riddle   }
899469c02d0SRiver Riddle 
900469c02d0SRiver Riddle   return success(mergedAnyBlocks);
901469c02d0SRiver Riddle }
902469c02d0SRiver Riddle 
903469c02d0SRiver Riddle /// Identify identical blocks within the given regions and merge them, inserting
904469c02d0SRiver Riddle /// new block arguments as necessary.
905d75a611aSRiver Riddle static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
906d75a611aSRiver Riddle                                           MutableArrayRef<Region> regions) {
907469c02d0SRiver Riddle   llvm::SmallSetVector<Region *, 1> worklist;
908469c02d0SRiver Riddle   for (auto &region : regions)
909469c02d0SRiver Riddle     worklist.insert(&region);
910469c02d0SRiver Riddle   bool anyChanged = false;
911469c02d0SRiver Riddle   while (!worklist.empty()) {
912469c02d0SRiver Riddle     Region *region = worklist.pop_back_val();
913d75a611aSRiver Riddle     if (succeeded(mergeIdenticalBlocks(rewriter, *region))) {
914469c02d0SRiver Riddle       worklist.insert(region);
915469c02d0SRiver Riddle       anyChanged = true;
916469c02d0SRiver Riddle     }
917469c02d0SRiver Riddle 
918469c02d0SRiver Riddle     // Add any nested regions to the worklist.
919469c02d0SRiver Riddle     for (Block &block : *region)
920469c02d0SRiver Riddle       for (auto &op : block)
921469c02d0SRiver Riddle         for (auto &nestedRegion : op.getRegions())
922469c02d0SRiver Riddle           worklist.insert(&nestedRegion);
923469c02d0SRiver Riddle   }
924469c02d0SRiver Riddle 
925469c02d0SRiver Riddle   return success(anyChanged);
926469c02d0SRiver Riddle }
927469c02d0SRiver Riddle 
928441b672bSGiuseppe Rossini /// If a block's argument is always the same across different invocations, then
929441b672bSGiuseppe Rossini /// drop the argument and use the value directly inside the block
930441b672bSGiuseppe Rossini static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
931441b672bSGiuseppe Rossini                                             Block &block) {
932441b672bSGiuseppe Rossini   SmallVector<size_t> argsToErase;
933441b672bSGiuseppe Rossini 
934441b672bSGiuseppe Rossini   // Go through the arguments of the block.
935441b672bSGiuseppe Rossini   for (auto [argIdx, blockOperand] : llvm::enumerate(block.getArguments())) {
936441b672bSGiuseppe Rossini     bool sameArg = true;
937441b672bSGiuseppe Rossini     Value commonValue;
938441b672bSGiuseppe Rossini 
939441b672bSGiuseppe Rossini     // Go through the block predecessor and flag if they pass to the block
940441b672bSGiuseppe Rossini     // different values for the same argument.
941441b672bSGiuseppe Rossini     for (Block::pred_iterator predIt = block.pred_begin(),
942441b672bSGiuseppe Rossini                               predE = block.pred_end();
943441b672bSGiuseppe Rossini          predIt != predE; ++predIt) {
944441b672bSGiuseppe Rossini       auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
945441b672bSGiuseppe Rossini       if (!branch) {
946441b672bSGiuseppe Rossini         sameArg = false;
947441b672bSGiuseppe Rossini         break;
948441b672bSGiuseppe Rossini       }
949441b672bSGiuseppe Rossini       unsigned succIndex = predIt.getSuccessorIndex();
950441b672bSGiuseppe Rossini       SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
951441b672bSGiuseppe Rossini       auto branchOperands = succOperands.getForwardedOperands();
952441b672bSGiuseppe Rossini       if (!commonValue) {
953441b672bSGiuseppe Rossini         commonValue = branchOperands[argIdx];
954441b672bSGiuseppe Rossini         continue;
955441b672bSGiuseppe Rossini       }
956441b672bSGiuseppe Rossini       if (branchOperands[argIdx] != commonValue) {
957441b672bSGiuseppe Rossini         sameArg = false;
958441b672bSGiuseppe Rossini         break;
959441b672bSGiuseppe Rossini       }
960441b672bSGiuseppe Rossini     }
961441b672bSGiuseppe Rossini 
962441b672bSGiuseppe Rossini     // If they are passing the same value, drop the argument.
963441b672bSGiuseppe Rossini     if (commonValue && sameArg) {
964441b672bSGiuseppe Rossini       argsToErase.push_back(argIdx);
965441b672bSGiuseppe Rossini 
966441b672bSGiuseppe Rossini       // Remove the argument from the block.
967441b672bSGiuseppe Rossini       rewriter.replaceAllUsesWith(blockOperand, commonValue);
968441b672bSGiuseppe Rossini     }
969441b672bSGiuseppe Rossini   }
970441b672bSGiuseppe Rossini 
971441b672bSGiuseppe Rossini   // Remove the arguments.
972441b672bSGiuseppe Rossini   for (size_t argIdx : llvm::reverse(argsToErase)) {
973441b672bSGiuseppe Rossini     block.eraseArgument(argIdx);
974441b672bSGiuseppe Rossini 
975441b672bSGiuseppe Rossini     // Remove the argument from the branch ops.
976441b672bSGiuseppe Rossini     for (auto predIt = block.pred_begin(), predE = block.pred_end();
977441b672bSGiuseppe Rossini          predIt != predE; ++predIt) {
978441b672bSGiuseppe Rossini       auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
979441b672bSGiuseppe Rossini       unsigned succIndex = predIt.getSuccessorIndex();
980441b672bSGiuseppe Rossini       SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
981441b672bSGiuseppe Rossini       succOperands.erase(argIdx);
982441b672bSGiuseppe Rossini     }
983441b672bSGiuseppe Rossini   }
984441b672bSGiuseppe Rossini   return success(!argsToErase.empty());
985441b672bSGiuseppe Rossini }
986441b672bSGiuseppe Rossini 
/// This optimization drops redundant block arguments. I.e., if a given
/// argument of a block receives the same value from each of the block's
/// predecessors, we can remove the argument from the block and use the
/// original value directly. This is a simple example:
991441b672bSGiuseppe Rossini ///
992441b672bSGiuseppe Rossini /// %cond = llvm.call @rand() : () -> i1
993441b672bSGiuseppe Rossini /// %val0 = llvm.mlir.constant(1 : i64) : i64
994441b672bSGiuseppe Rossini /// %val1 = llvm.mlir.constant(2 : i64) : i64
995441b672bSGiuseppe Rossini /// %val2 = llvm.mlir.constant(3 : i64) : i64
996441b672bSGiuseppe Rossini /// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
997441b672bSGiuseppe Rossini /// : i64)
998441b672bSGiuseppe Rossini ///
999441b672bSGiuseppe Rossini /// ^bb1(%arg0 : i64, %arg1 : i64):
1000441b672bSGiuseppe Rossini ///    llvm.call @foo(%arg0, %arg1)
1001441b672bSGiuseppe Rossini ///
1002441b672bSGiuseppe Rossini /// The previous IR can be rewritten as:
1003441b672bSGiuseppe Rossini /// %cond = llvm.call @rand() : () -> i1
1004441b672bSGiuseppe Rossini /// %val0 = llvm.mlir.constant(1 : i64) : i64
1005441b672bSGiuseppe Rossini /// %val1 = llvm.mlir.constant(2 : i64) : i64
1006441b672bSGiuseppe Rossini /// %val2 = llvm.mlir.constant(3 : i64) : i64
1007441b672bSGiuseppe Rossini /// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
1008441b672bSGiuseppe Rossini ///
1009441b672bSGiuseppe Rossini /// ^bb1(%arg0 : i64):
1010441b672bSGiuseppe Rossini ///    llvm.call @foo(%val0, %arg0)
1011441b672bSGiuseppe Rossini ///
1012441b672bSGiuseppe Rossini static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
1013441b672bSGiuseppe Rossini                                             MutableArrayRef<Region> regions) {
1014441b672bSGiuseppe Rossini   llvm::SmallSetVector<Region *, 1> worklist;
1015441b672bSGiuseppe Rossini   for (Region &region : regions)
1016441b672bSGiuseppe Rossini     worklist.insert(&region);
1017441b672bSGiuseppe Rossini   bool anyChanged = false;
1018441b672bSGiuseppe Rossini   while (!worklist.empty()) {
1019441b672bSGiuseppe Rossini     Region *region = worklist.pop_back_val();
1020441b672bSGiuseppe Rossini 
1021441b672bSGiuseppe Rossini     // Add any nested regions to the worklist.
1022441b672bSGiuseppe Rossini     for (Block &block : *region) {
1023a0241e71SGiuseppe Rossini       anyChanged =
1024a0241e71SGiuseppe Rossini           succeeded(dropRedundantArguments(rewriter, block)) || anyChanged;
1025441b672bSGiuseppe Rossini 
1026441b672bSGiuseppe Rossini       for (Operation &op : block)
1027441b672bSGiuseppe Rossini         for (Region &nestedRegion : op.getRegions())
1028441b672bSGiuseppe Rossini           worklist.insert(&nestedRegion);
1029441b672bSGiuseppe Rossini     }
1030441b672bSGiuseppe Rossini   }
1031441b672bSGiuseppe Rossini   return success(anyChanged);
1032441b672bSGiuseppe Rossini }
1033441b672bSGiuseppe Rossini 
1034469c02d0SRiver Riddle //===----------------------------------------------------------------------===//
1035fafb708bSRiver Riddle // Region Simplification
1036fafb708bSRiver Riddle //===----------------------------------------------------------------------===//
1037fafb708bSRiver Riddle 
1038fafb708bSRiver Riddle /// Run a set of structural simplifications over the given regions. This
1039fafb708bSRiver Riddle /// includes transformations like unreachable block elimination, dead argument
1040fafb708bSRiver Riddle /// elimination, as well as some other DCE. This function returns success if any
1041fafb708bSRiver Riddle /// of the regions were simplified, failure otherwise.
1042d75a611aSRiver Riddle LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
1043a506279eSMehdi Amini                                     MutableArrayRef<Region> regions,
1044a506279eSMehdi Amini                                     bool mergeBlocks) {
1045d75a611aSRiver Riddle   bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
1046d75a611aSRiver Riddle   bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
1047a506279eSMehdi Amini   bool mergedIdenticalBlocks = false;
1048441b672bSGiuseppe Rossini   bool droppedRedundantArguments = false;
1049441b672bSGiuseppe Rossini   if (mergeBlocks) {
1050a506279eSMehdi Amini     mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
1051441b672bSGiuseppe Rossini     droppedRedundantArguments =
1052441b672bSGiuseppe Rossini         succeeded(dropRedundantArguments(rewriter, regions));
1053441b672bSGiuseppe Rossini   }
1054469c02d0SRiver Riddle   return success(eliminatedBlocks || eliminatedOpsOrArgs ||
1055441b672bSGiuseppe Rossini                  mergedIdenticalBlocks || droppedRedundantArguments);
1056fafb708bSRiver Riddle }
1057