//===- RegionUtils.cpp - Region-related transformation utilities ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/RegionUtils.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/RegionGraphTraits.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/PostOrderIterator.h"

#include <deque>

using namespace mlir;

void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement,
                                      Region &region) {
  for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
    if (region.isAncestor(use.getOwner()->getParentRegion()))
      use.set(replacement);
  }
}
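
// Usage sketch (illustrative; `op`, `oldValue`, and `newValue` are
// hypothetical): replace uses of `oldValue` with `newValue`, but only for
// uses nested within the op's first region.
//
//   replaceAllUsesInRegionWith(oldValue, newValue, op->getRegion(0));
//
// Uses of `oldValue` outside that region are left untouched.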

void mlir::visitUsedValuesDefinedAbove(
    Region &region, Region &limit, function_ref<void(OpOperand *)> callback) {
  assert(limit.isAncestor(&region) &&
         "expected isolation limit to be an ancestor of the given region");

  // Collect proper ancestors of `limit` upfront to avoid traversing the region
  // tree for every value.
  SmallPtrSet<Region *, 4> properAncestors;
  for (auto *reg = limit.getParentRegion(); reg != nullptr;
       reg = reg->getParentRegion()) {
    properAncestors.insert(reg);
  }

  region.walk([callback, &properAncestors](Operation *op) {
    for (OpOperand &operand : op->getOpOperands())
      // Callback on values defined in a proper ancestor of region.
      if (properAncestors.count(operand.get().getParentRegion()))
        callback(&operand);
  });
}

void mlir::visitUsedValuesDefinedAbove(
    MutableArrayRef<Region> regions, function_ref<void(OpOperand *)> callback) {
  for (Region &region : regions)
    visitUsedValuesDefinedAbove(region, region, callback);
}

void mlir::getUsedValuesDefinedAbove(Region &region, Region &limit,
                                     SetVector<Value> &values) {
  visitUsedValuesDefinedAbove(region, limit, [&](OpOperand *operand) {
    values.insert(operand->get());
  });
}

void mlir::getUsedValuesDefinedAbove(MutableArrayRef<Region> regions,
                                     SetVector<Value> &values) {
  for (Region &region : regions)
    getUsedValuesDefinedAbove(region, region, values);
}
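
// Example sketch (hypothetical `op`): collect every value that the regions of
// `op` use but that is defined outside of them, e.g. to decide which values
// must become explicit captures.
//
//   llvm::SetVector<Value> capturedValues;
//   getUsedValuesDefinedAbove(op->getRegions(), capturedValues);
//
// `capturedValues` then holds the values in the order they were first seen.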

//===----------------------------------------------------------------------===//
// Make a region isolated from above.
//===----------------------------------------------------------------------===//

SmallVector<Value> mlir::makeRegionIsolatedFromAbove(
    RewriterBase &rewriter, Region &region,
    llvm::function_ref<bool(Operation *)> cloneOperationIntoRegion) {

  // Get initial list of values used within region but defined above.
  llvm::SetVector<Value> initialCapturedValues;
  mlir::getUsedValuesDefinedAbove(region, initialCapturedValues);

  std::deque<Value> worklist(initialCapturedValues.begin(),
                             initialCapturedValues.end());
  llvm::DenseSet<Value> visited;
  llvm::DenseSet<Operation *> visitedOps;

  llvm::SetVector<Value> finalCapturedValues;
  SmallVector<Operation *> clonedOperations;
  while (!worklist.empty()) {
    Value currValue = worklist.front();
    worklist.pop_front();
    if (visited.count(currValue))
      continue;
    visited.insert(currValue);

    Operation *definingOp = currValue.getDefiningOp();
    if (!definingOp || visitedOps.count(definingOp)) {
      finalCapturedValues.insert(currValue);
      continue;
    }
    visitedOps.insert(definingOp);

    if (!cloneOperationIntoRegion(definingOp)) {
      // The defining operation isn't cloned, so add the current value to the
      // final captured values list.
      finalCapturedValues.insert(currValue);
      continue;
    }

    // Add all operands of the operation to the worklist and mark the op as to
    // be cloned.
    for (Value operand : definingOp->getOperands()) {
      if (visited.count(operand))
        continue;
      worklist.push_back(operand);
    }
    clonedOperations.push_back(definingOp);
  }

  // The operations to be cloned need to be sorted topologically so that they
  // can be cloned into the region without violating use-def chains.
  mlir::computeTopologicalSorting(clonedOperations);

  OpBuilder::InsertionGuard g(rewriter);
  // Collect the argument types and locations of the existing entry block.
  Block *entryBlock = &region.front();
  SmallVector<Type> newArgTypes =
      llvm::to_vector(entryBlock->getArgumentTypes());
  SmallVector<Location> newArgLocs = llvm::to_vector(llvm::map_range(
      entryBlock->getArguments(), [](BlockArgument b) { return b.getLoc(); }));

  // Append the types and locations of the captured values.
  for (auto value : finalCapturedValues) {
    newArgTypes.push_back(value.getType());
    newArgLocs.push_back(value.getLoc());
  }

  // Create a new entry block.
  Block *newEntryBlock =
      rewriter.createBlock(&region, region.begin(), newArgTypes, newArgLocs);
  auto newEntryBlockArgs = newEntryBlock->getArguments();

  // Create a mapping between the captured values and the new arguments added.
  IRMapping map;
  auto replaceIfFn = [&](OpOperand &use) {
    return use.getOwner()->getBlock()->getParent() == &region;
  };
  for (auto [arg, capturedVal] :
       llvm::zip(newEntryBlockArgs.take_back(finalCapturedValues.size()),
                 finalCapturedValues)) {
    map.map(capturedVal, arg);
    rewriter.replaceUsesWithIf(capturedVal, arg, replaceIfFn);
  }
  rewriter.setInsertionPointToStart(newEntryBlock);
  for (auto *clonedOp : clonedOperations) {
    Operation *newOp = rewriter.clone(*clonedOp, map);
    rewriter.replaceOpUsesWithIf(clonedOp, newOp->getResults(), replaceIfFn);
  }
  rewriter.mergeBlocks(
      entryBlock, newEntryBlock,
      newEntryBlock->getArguments().take_front(entryBlock->getNumArguments()));
  return llvm::to_vector(finalCapturedValues);
}
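
// Usage sketch (assumes a hypothetical `op` with a single non-isolated
// region): pull constant-like producers into the region and turn every other
// captured value into a new entry block argument.
//
//   SmallVector<Value> newCaptures = makeRegionIsolatedFromAbove(
//       rewriter, op->getRegion(0), [](Operation *producer) {
//         return producer->hasTrait<OpTrait::ConstantLike>();
//       });
//
// The returned values are the remaining captures, in the same order as the
// arguments appended to the entry block; callers typically forward them as
// operands of a newly created isolated op.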

//===----------------------------------------------------------------------===//
// Unreachable Block Elimination
//===----------------------------------------------------------------------===//

/// Erase the unreachable blocks within the provided regions. Returns success
/// if any blocks were erased, failure otherwise.
// TODO: We could likely merge this with the DCE algorithm below.
LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter,
                                           MutableArrayRef<Region> regions) {
  // Set of blocks found to be reachable within a given region.
  llvm::df_iterator_default_set<Block *, 16> reachable;
  // If any blocks were found to be dead.
  bool erasedDeadBlocks = false;

  SmallVector<Region *, 1> worklist;
  worklist.reserve(regions.size());
  for (Region &region : regions)
    worklist.push_back(&region);
  while (!worklist.empty()) {
    Region *region = worklist.pop_back_val();
    if (region->empty())
      continue;

    // If this is a single block region, just collect the nested regions.
    if (std::next(region->begin()) == region->end()) {
      for (Operation &op : region->front())
        for (Region &region : op.getRegions())
          worklist.push_back(&region);
      continue;
    }

    // Mark all reachable blocks.
    reachable.clear();
    for (Block *block : depth_first_ext(&region->front(), reachable))
      (void)block /* Mark all reachable blocks */;

    // Collect all of the dead blocks and push the live regions onto the
    // worklist.
    for (Block &block : llvm::make_early_inc_range(*region)) {
      if (!reachable.count(&block)) {
        block.dropAllDefinedValueUses();
        rewriter.eraseBlock(&block);
        erasedDeadBlocks = true;
        continue;
      }

      // Walk any regions within this block.
      for (Operation &op : block)
        for (Region &region : op.getRegions())
          worklist.push_back(&region);
    }
  }

  return success(erasedDeadBlocks);
}
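
// Illustrative IR (hypothetical function): `^bb1` has no predecessors and is
// not the entry block, so it is unreachable and gets erased; `^bb2` stays.
//
//   func.func @example() {
//     cf.br ^bb2
//   ^bb1:  // no predecessors: unreachable, erased
//     "test.foo"() : () -> ()
//     cf.br ^bb2
//   ^bb2:
//     return
//   }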

//===----------------------------------------------------------------------===//
// Dead Code Elimination
//===----------------------------------------------------------------------===//

namespace {
/// Data structure used to track which values have already been proved live.
///
/// Because an Operation can have multiple results, this data structure tracks
/// liveness for both Values and Operations to avoid having to look through
/// all Operation results when analyzing a use.
///
/// This data structure essentially tracks the dataflow lattice.
/// The set of values/ops proved live increases monotonically to a fixed-point.
class LiveMap {
public:
  /// Value methods.
  bool wasProvenLive(Value value) {
    // TODO: For results that are removable, e.g. for region-based control
    // flow, we could allow for these values to be tracked independently.
    if (OpResult result = dyn_cast<OpResult>(value))
      return wasProvenLive(result.getOwner());
    return wasProvenLive(cast<BlockArgument>(value));
  }
  bool wasProvenLive(BlockArgument arg) { return liveValues.count(arg); }
  void setProvedLive(Value value) {
    // TODO: For results that are removable, e.g. for region-based control
    // flow, we could allow for these values to be tracked independently.
    if (OpResult result = dyn_cast<OpResult>(value))
      return setProvedLive(result.getOwner());
    setProvedLive(cast<BlockArgument>(value));
  }
  void setProvedLive(BlockArgument arg) {
    changed |= liveValues.insert(arg).second;
  }

  /// Operation methods.
  bool wasProvenLive(Operation *op) { return liveOps.count(op); }
  void setProvedLive(Operation *op) { changed |= liveOps.insert(op).second; }

  /// Methods for tracking if we have reached a fixed-point.
  void resetChanged() { changed = false; }
  bool hasChanged() { return changed; }

private:
  bool changed = false;
  DenseSet<Value> liveValues;
  DenseSet<Operation *> liveOps;
};
} // namespace

static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap) {
  Operation *owner = use.getOwner();
  unsigned operandIndex = use.getOperandNumber();
  // This pass generally treats all uses of an op as live if the op itself is
  // considered live. However, for successor operands to terminators we need a
  // finer-grained notion where we deduce liveness for operands individually.
  // The reason for this is easiest to think about in terms of a classical phi
  // node based SSA IR, where each successor operand is really an operand to a
  // *separate* phi node, rather than all operands to the branch itself as with
  // the block argument representation that MLIR uses.
  //
  // And similarly, because each successor operand is really an operand to a phi
  // node, rather than to the terminator op itself, a terminator op can't e.g.
  // "print" the value of a successor operand.
  if (owner->hasTrait<OpTrait::IsTerminator>()) {
    if (BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(owner))
      if (auto arg = branchInterface.getSuccessorBlockArgument(operandIndex))
        return !liveMap.wasProvenLive(*arg);
    return false;
  }
  return false;
}
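
// Example sketch (illustrative IR): the use of `%v` below is a successor
// operand feeding block argument `%x`. If `%x` is never proven live, that use
// alone does not keep `%v` alive, even though `cf.br` itself is live.
//
//     cf.br ^bb1(%v : i32)
//   ^bb1(%x: i32):  // %x unused -> the branch use of %v is "specially known
//     return        // dead"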

static void processValue(Value value, LiveMap &liveMap) {
  bool provedLive = llvm::any_of(value.getUses(), [&](OpOperand &use) {
    if (isUseSpeciallyKnownDead(use, liveMap))
      return false;
    return liveMap.wasProvenLive(use.getOwner());
  });
  if (provedLive)
    liveMap.setProvedLive(value);
}

static void propagateLiveness(Region &region, LiveMap &liveMap);

static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
  // Terminators are always live.
  liveMap.setProvedLive(op);

  // Check to see if we can reason about the successor operands and mutate
  // them.
  BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(op);
  if (!branchInterface) {
    for (Block *successor : op->getSuccessors())
      for (BlockArgument arg : successor->getArguments())
        liveMap.setProvedLive(arg);
    return;
  }

  // If we can't reason about the operand to a successor, conservatively mark
  // it as live.
  for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
    SuccessorOperands successorOperands =
        branchInterface.getSuccessorOperands(i);
    for (unsigned opI = 0, opE = successorOperands.getProducedOperandCount();
         opI != opE; ++opI)
      liveMap.setProvedLive(op->getSuccessor(i)->getArgument(opI));
  }
}

static void propagateLiveness(Operation *op, LiveMap &liveMap) {
  // Recurse on any regions the op has.
  for (Region &region : op->getRegions())
    propagateLiveness(region, liveMap);

  // Process terminator operations.
  if (op->hasTrait<OpTrait::IsTerminator>())
    return propagateTerminatorLiveness(op, liveMap);

  // Don't reprocess live operations.
  if (liveMap.wasProvenLive(op))
    return;

  // Process the op itself.
  if (!wouldOpBeTriviallyDead(op))
    return liveMap.setProvedLive(op);

  // If the op isn't intrinsically alive, check its results.
  for (Value value : op->getResults())
    processValue(value, liveMap);
}

static void propagateLiveness(Region &region, LiveMap &liveMap) {
  if (region.empty())
    return;

  for (Block *block : llvm::post_order(&region.front())) {
    // We process block arguments after the ops in the block, to promote
    // faster convergence to a fixed point (we try to visit uses before defs).
    for (Operation &op : llvm::reverse(block->getOperations()))
      propagateLiveness(&op, liveMap);

    // We currently do not remove entry block arguments, so there is no need to
    // track their liveness.
    // TODO: We could track these and enable removing dead operands/arguments
    // from region control flow operations.
    if (block->isEntryBlock())
      continue;

    for (Value value : block->getArguments()) {
      if (!liveMap.wasProvenLive(value))
        processValue(value, liveMap);
    }
  }
}

static void eraseTerminatorSuccessorOperands(Operation *terminator,
                                             LiveMap &liveMap) {
  BranchOpInterface branchOp = dyn_cast<BranchOpInterface>(terminator);
  if (!branchOp)
    return;

  for (unsigned succI = 0, succE = terminator->getNumSuccessors();
       succI < succE; succI++) {
    // Iterating successors in reverse is not strictly needed, since we
    // aren't erasing any successors. But it is slightly more efficient
    // since it causes later operands of the terminator to be erased first,
    // reducing the quadratic cost of shifting the remaining operands.
    unsigned succ = succE - succI - 1;
    SuccessorOperands succOperands = branchOp.getSuccessorOperands(succ);
    Block *successor = terminator->getSuccessor(succ);

    for (unsigned argI = 0, argE = succOperands.size(); argI < argE; ++argI) {
      // Iterating args in reverse is needed for correctness, to avoid
      // shifting later args when earlier args are erased.
      unsigned arg = argE - argI - 1;
      if (!liveMap.wasProvenLive(successor->getArgument(arg)))
        succOperands.erase(arg);
    }
  }
}

static LogicalResult deleteDeadness(RewriterBase &rewriter,
                                    MutableArrayRef<Region> regions,
                                    LiveMap &liveMap) {
  bool erasedAnything = false;
  for (Region &region : regions) {
    if (region.empty())
      continue;
    bool hasSingleBlock = llvm::hasSingleElement(region);

    // Delete every operation that is not live. Graph regions may have cycles
    // in the use-def graph, so we must explicitly dropAllUses() from each
    // operation as we erase it. Visiting the operations in post-order
    // guarantees that in SSA CFG regions value uses are removed before defs,
    // which makes dropAllUses() a no-op.
    for (Block *block : llvm::post_order(&region.front())) {
      if (!hasSingleBlock)
        eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap);
      for (Operation &childOp :
           llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) {
        if (!liveMap.wasProvenLive(&childOp)) {
          erasedAnything = true;
          childOp.dropAllUses();
          rewriter.eraseOp(&childOp);
        } else {
          erasedAnything |= succeeded(
              deleteDeadness(rewriter, childOp.getRegions(), liveMap));
        }
      }
    }
    // Delete block arguments.
    // The entry block has an unknown contract with its enclosing operation, so
    // skip it.
    for (Block &block : llvm::drop_begin(region.getBlocks(), 1)) {
      block.eraseArguments(
          [&](BlockArgument arg) { return !liveMap.wasProvenLive(arg); });
    }
  }
  return success(erasedAnything);
}

// This function performs a simple dead code elimination algorithm over the
// given regions.
//
// The overall goal is to prove that Values are dead, which allows deleting ops
// and block arguments.
//
// This uses an optimistic algorithm that assumes everything is dead until
// proved otherwise, allowing it to delete recursively dead cycles.
//
// This is a simple fixed-point dataflow analysis algorithm on a lattice
// {Dead,Alive}. Because liveness flows backward, we generally try to
// iterate everything backward to speed up convergence to the fixed-point. This
// also makes it possible to delete recursively dead cycles of the use-def
// graph, including block arguments.
//
// This function returns success if any operations or arguments were deleted,
// failure otherwise.
LogicalResult mlir::runRegionDCE(RewriterBase &rewriter,
                                 MutableArrayRef<Region> regions) {
  LiveMap liveMap;
  do {
    liveMap.resetChanged();

    for (Region &region : regions)
      propagateLiveness(region, liveMap);
  } while (liveMap.hasChanged());

  return deleteDeadness(rewriter, regions, liveMap);
}
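
// Usage sketch (hypothetical pass body): run region DCE over all regions of
// the root op a pass is anchored on. `rootOp` stands in for whatever op the
// caller owns.
//
//   IRRewriter rewriter(rootOp->getContext());
//   if (succeeded(runRegionDCE(rewriter, rootOp->getRegions()))) {
//     // Something was deleted; callers may e.g. re-run canonicalization.
//   }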

//===----------------------------------------------------------------------===//
// Block Merging
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// BlockEquivalenceData

namespace {
/// This class contains the information needed to compare two blocks for
/// equivalence. Blocks are considered equivalent if they contain the same
/// operations in the same order. The only allowed divergence is for operands
/// that come from sources outside of the parent block, i.e. the uses of values
/// produced within the block must be equivalent.
///   e.g.,
/// Equivalent:
///  ^bb1(%arg0: i32)
///    return %arg0, %foo : i32, i32
///  ^bb2(%arg1: i32)
///    return %arg1, %bar : i32, i32
/// Not Equivalent:
///  ^bb1(%arg0: i32)
///    return %foo, %arg0 : i32, i32
///  ^bb2(%arg1: i32)
///    return %arg1, %bar : i32, i32
struct BlockEquivalenceData {
  BlockEquivalenceData(Block *block);

  /// Return the order index for the given value that is within the block of
  /// this data.
  unsigned getOrderOf(Value value) const;

  /// The block this data refers to.
  Block *block;
  /// A hash value for this block.
  llvm::hash_code hash;
  /// A map of result producing operations to their relative orders within this
  /// block. The order of an operation is the number of defined values that are
  /// produced within the block before this operation.
  DenseMap<Operation *, unsigned> opOrderIndex;
};
} // namespace

BlockEquivalenceData::BlockEquivalenceData(Block *block)
    : block(block), hash(0) {
  unsigned orderIt = block->getNumArguments();
  for (Operation &op : *block) {
    if (unsigned numResults = op.getNumResults()) {
      opOrderIndex.try_emplace(&op, orderIt);
      orderIt += numResults;
    }
    auto opHash = OperationEquivalence::computeHash(
        &op, OperationEquivalence::ignoreHashValue,
        OperationEquivalence::ignoreHashValue,
        OperationEquivalence::IgnoreLocations);
    hash = llvm::hash_combine(hash, opHash);
  }
}

unsigned BlockEquivalenceData::getOrderOf(Value value) const {
  assert(value.getParentBlock() == block && "expected value of this block");

  // Arguments use the argument number as the order index.
  if (BlockArgument arg = dyn_cast<BlockArgument>(value))
    return arg.getArgNumber();

  // Otherwise, the result order is offset from the parent op's order.
  OpResult result = cast<OpResult>(value);
  auto opOrderIt = opOrderIndex.find(result.getDefiningOp());
  assert(opOrderIt != opOrderIndex.end() && "expected op to have an order");
  return opOrderIt->second + result.getResultNumber();
}
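
// Worked example (illustrative block): with two block arguments and two ops,
// the order indices assigned by getOrderOf are:
//
//   ^bb(%a: i32, %b: i32):                      // %a -> 0, %b -> 1
//     %c:2 = "test.pair"() : () -> (i32, i32)   // %c#0 -> 2, %c#1 -> 3
//     %d = "test.one"(%c#0) : (i32) -> i32      // %d -> 4
//
// Two blocks are only mergeable if matching in-block operands agree on these
// indices.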

//===----------------------------------------------------------------------===//
// BlockMergeCluster

namespace {
/// This class represents a cluster of blocks to be merged together.
class BlockMergeCluster {
public:
  BlockMergeCluster(BlockEquivalenceData &&leaderData)
      : leaderData(std::move(leaderData)) {}

  /// Attempt to add the given block to this cluster. Returns success if the
  /// block was merged, failure otherwise.
  LogicalResult addToCluster(BlockEquivalenceData &blockData);

  /// Try to merge all of the blocks within this cluster into the leader block.
  LogicalResult merge(RewriterBase &rewriter);

private:
  /// The equivalence data for the leader of the cluster.
  BlockEquivalenceData leaderData;

  /// The set of blocks that can be merged into the leader.
  llvm::SmallSetVector<Block *, 1> blocksToMerge;

  /// A set of (operation index, operand index) pairs identifying operands that
  /// need to be replaced by arguments when the cluster gets merged.
  std::set<std::pair<int, int>> operandsToMerge;
};
} // namespace

LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
  if (leaderData.hash != blockData.hash)
    return failure();
  Block *leaderBlock = leaderData.block, *mergeBlock = blockData.block;
  if (leaderBlock->getArgumentTypes() != mergeBlock->getArgumentTypes())
    return failure();

  // A set of operands that mismatch between the leader and the new block.
  SmallVector<std::pair<int, int>, 8> mismatchedOperands;
  auto lhsIt = leaderBlock->begin(), lhsE = leaderBlock->end();
  auto rhsIt = blockData.block->begin(), rhsE = blockData.block->end();
  for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) {
    // Check that the operations are equivalent.
    if (!OperationEquivalence::isEquivalentTo(
            &*lhsIt, &*rhsIt, OperationEquivalence::ignoreValueEquivalence,
            /*markEquivalent=*/nullptr,
            OperationEquivalence::Flags::IgnoreLocations))
      return failure();

    // Compare the operands of the two operations. If the operand is within
    // the block, it must refer to the same operation.
    auto lhsOperands = lhsIt->getOperands(), rhsOperands = rhsIt->getOperands();
    for (int operand : llvm::seq<int>(0, lhsIt->getNumOperands())) {
      Value lhsOperand = lhsOperands[operand];
      Value rhsOperand = rhsOperands[operand];
      if (lhsOperand == rhsOperand)
        continue;
      // Check that the types of the operands match.
      if (lhsOperand.getType() != rhsOperand.getType())
        return failure();

      // Check that these uses are both external, or both internal.
      bool lhsIsInBlock = lhsOperand.getParentBlock() == leaderBlock;
      bool rhsIsInBlock = rhsOperand.getParentBlock() == mergeBlock;
      if (lhsIsInBlock != rhsIsInBlock)
        return failure();
      // Let the operands differ if they are defined in a different block. These
      // will become new arguments if the blocks get merged.
      if (!lhsIsInBlock) {

        // Check whether the operand is a result of the immediate predecessor's
        // terminator. In that case we are not able to use it as a successor
        // operand when branching to the merged block, as it is not defined
        // before the branch that would have to pass it.
        auto isValidSuccessorArg = [](Block *block, Value operand) {
          if (operand.getDefiningOp() !=
              operand.getParentBlock()->getTerminator())
            return true;
          return !llvm::is_contained(block->getPredecessors(),
                                     operand.getParentBlock());
        };

        if (!isValidSuccessorArg(leaderBlock, lhsOperand) ||
            !isValidSuccessorArg(mergeBlock, rhsOperand))
          return failure();

        mismatchedOperands.emplace_back(opI, operand);
        continue;
      }

      // Otherwise, these operands must have the same logical order within the
      // parent block.
      if (leaderData.getOrderOf(lhsOperand) != blockData.getOrderOf(rhsOperand))
        return failure();
    }

    // If the lhs or rhs has external uses, the blocks cannot be merged as the
    // merged version of this operation will not be either the lhs or rhs
    // alone (thus semantically incorrect), but some mix depending on which
    // block preceded this.
    // TODO: Allow merging of operations when one block does not dominate the
    // other.
    if (rhsIt->isUsedOutsideOfBlock(mergeBlock) ||
        lhsIt->isUsedOutsideOfBlock(leaderBlock)) {
      return failure();
    }
  }
  // Make sure that the block sizes are equivalent.
  if (lhsIt != lhsE || rhsIt != rhsE)
    return failure();

  // If we get here, the blocks are equivalent and can be merged.
  operandsToMerge.insert(mismatchedOperands.begin(), mismatchedOperands.end());
  blocksToMerge.insert(blockData.block);
  return success();
}

/// Returns true if the predecessor terminators of the given block can have
/// their operands updated.
static bool ableToUpdatePredOperands(Block *block) {
  for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
    if (!isa<BranchOpInterface>((*it)->getTerminator()))
      return false;
  }
  return true;
}

LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
  // Don't consider clusters that don't have blocks to merge.
  if (blocksToMerge.empty())
    return failure();

  Block *leaderBlock = leaderData.block;
  if (!operandsToMerge.empty()) {
    // If the cluster has operands to merge, verify that the predecessor
    // terminators of each of the blocks can have their successor operands
    // updated.
    // TODO: We could try and sub-partition this cluster if only some blocks
    // cause the mismatch.
    if (!ableToUpdatePredOperands(leaderBlock) ||
        !llvm::all_of(blocksToMerge, ableToUpdatePredOperands))
      return failure();

    // Collect the iterators for each of the blocks to merge. We will walk all
    // of the iterators at once to avoid operand index invalidation.
    SmallVector<Block::iterator, 2> blockIterators;
    blockIterators.reserve(blocksToMerge.size() + 1);
    blockIterators.push_back(leaderBlock->begin());
    for (Block *mergeBlock : blocksToMerge)
      blockIterators.push_back(mergeBlock->begin());

    // Update each of the predecessor terminators with the new arguments.
    SmallVector<SmallVector<Value, 8>, 2> newArguments(
        1 + blocksToMerge.size(),
        SmallVector<Value, 8>(operandsToMerge.size()));
    unsigned curOpIndex = 0;
    for (const auto &it : llvm::enumerate(operandsToMerge)) {
      unsigned nextOpOffset = it.value().first - curOpIndex;
      curOpIndex = it.value().first;

      // Process the operand for each of the block iterators.
      for (unsigned i = 0, e = blockIterators.size(); i != e; ++i) {
        Block::iterator &blockIter = blockIterators[i];
        std::advance(blockIter, nextOpOffset);
        auto &operand = blockIter->getOpOperand(it.value().second);
        newArguments[i][it.index()] = operand.get();

        // Update the operand and insert an argument if this is the leader.
        if (i == 0) {
          Value operandVal = operand.get();
          operand.set(leaderBlock->addArgument(operandVal.getType(),
                                               operandVal.getLoc()));
        }
      }
    }
    // Update the predecessors for each of the blocks.
    auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
      for (auto predIt = block->pred_begin(), predE = block->pred_end();
           predIt != predE; ++predIt) {
        auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
        unsigned succIndex = predIt.getSuccessorIndex();
        branch.getSuccessorOperands(succIndex).append(
            newArguments[clusterIndex]);
      }
    };
    updatePredecessors(leaderBlock, /*clusterIndex=*/0);
    for (unsigned i = 0, e = blocksToMerge.size(); i != e; ++i)
      updatePredecessors(blocksToMerge[i], /*clusterIndex=*/i + 1);
  }

  // Replace all uses of the merged blocks with the leader and erase them.
  for (Block *block : blocksToMerge) {
    block->replaceAllUsesWith(leaderBlock);
    rewriter.eraseBlock(block);
  }
  return success();
}
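
// Illustrative before/after (hypothetical IR): ^bb1 and ^bb2 are identical up
// to one externally defined operand, so they are merged and the differing
// operand becomes a new block argument supplied by each predecessor edge.
//
//   Before:
//     cf.cond_br %c, ^bb1, ^bb2
//   ^bb1:
//     "test.use"(%x) : (i32) -> ()
//     return
//   ^bb2:
//     "test.use"(%y) : (i32) -> ()
//     return
//
//   After:
//     cf.cond_br %c, ^bb1(%x : i32), ^bb1(%y : i32)
//   ^bb1(%arg: i32):
//     "test.use"(%arg) : (i32) -> ()
//     return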

/// Identify identical blocks within the given region and merge them, inserting
/// new block arguments as necessary. Returns success if any blocks were merged,
/// failure otherwise.
static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
                                          Region &region) {
  if (region.empty() || llvm::hasSingleElement(region))
    return failure();

  // Identify sets of blocks, other than the entry block, that branch to the
  // same successors. We will use these groups to create clusters of equivalent
  // blocks.
  DenseMap<SuccessorRange, SmallVector<Block *, 1>> matchingSuccessors;
  for (Block &block : llvm::drop_begin(region, 1))
    matchingSuccessors[block.getSuccessors()].push_back(&block);

  bool mergedAnyBlocks = false;
  for (ArrayRef<Block *> blocks : llvm::make_second_range(matchingSuccessors)) {
    if (blocks.size() == 1)
      continue;

    SmallVector<BlockMergeCluster, 1> clusters;
    for (Block *block : blocks) {
      BlockEquivalenceData data(block);

      // Don't allow merging if this block has any regions.
      // TODO: Add support for regions if necessary.
      bool hasNonEmptyRegion = llvm::any_of(*block, [](Operation &op) {
        return llvm::any_of(op.getRegions(),
                            [](Region &region) { return !region.empty(); });
      });
      if (hasNonEmptyRegion)
        continue;

      // Try to add this block to an existing cluster.
      bool addedToCluster = false;
      for (auto &cluster : clusters)
        if ((addedToCluster = succeeded(cluster.addToCluster(data))))
          break;
      if (!addedToCluster)
        clusters.emplace_back(std::move(data));
    }
    for (auto &cluster : clusters)
      mergedAnyBlocks |= succeeded(cluster.merge(rewriter));
  }

  return success(mergedAnyBlocks);
}

/// Identify identical blocks within the given regions and merge them, inserting
/// new block arguments as necessary.
static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
                                          MutableArrayRef<Region> regions) {
  llvm::SmallSetVector<Region *, 1> worklist;
  for (auto &region : regions)
    worklist.insert(&region);
  bool anyChanged = false;
  while (!worklist.empty()) {
    Region *region = worklist.pop_back_val();
    if (succeeded(mergeIdenticalBlocks(rewriter, *region))) {
      worklist.insert(region);
      anyChanged = true;
    }

    // Add any nested regions to the worklist.
    for (Block &block : *region)
      for (auto &op : block)
        for (auto &nestedRegion : op.getRegions())
          worklist.insert(&nestedRegion);
  }

  return success(anyChanged);
}

//===----------------------------------------------------------------------===//
// Region Simplification
//===----------------------------------------------------------------------===//

/// Run a set of structural simplifications over the given regions. This
/// includes transformations like unreachable block elimination, dead argument
/// elimination, and other dead code elimination. This function returns success
/// if any of the regions were simplified, failure otherwise.
LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
                                    MutableArrayRef<Region> regions,
                                    bool mergeBlocks) {
  bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
  bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
  bool mergedIdenticalBlocks = false;
  if (mergeBlocks)
    mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
  return success(eliminatedBlocks || eliminatedOpsOrArgs ||
                 mergedIdenticalBlocks);
}
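
// Usage sketch (hypothetical rewrite driver, `rootOp` assumed): run the full
// set of structural simplifications, including block merging, over a root
// op's regions.
//
//   IRRewriter rewriter(rootOp->getContext());
//   (void)simplifyRegions(rewriter, rootOp->getRegions(),
//                         /*mergeBlocks=*/true);
//
// The result only reports whether anything changed; it is not an error state.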