xref: /llvm-project/mlir/lib/Transforms/Utils/RegionUtils.cpp (revision ab6a66dbec61654d0962f6abf6d6c5b776937584)
1 //===- RegionUtils.cpp - Region-related transformation utilities ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/RegionUtils.h"
10 #include "mlir/IR/Block.h"
11 #include "mlir/IR/IRMapping.h"
12 #include "mlir/IR/Operation.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/IR/RegionGraphTraits.h"
15 #include "mlir/IR/Value.h"
16 #include "mlir/Interfaces/ControlFlowInterfaces.h"
17 #include "mlir/Interfaces/SideEffectInterfaces.h"
18 #include "mlir/Transforms/TopologicalSortUtils.h"
19 
20 #include "llvm/ADT/DepthFirstIterator.h"
21 #include "llvm/ADT/PostOrderIterator.h"
22 #include "llvm/ADT/SmallSet.h"
23 
24 #include <deque>
25 
26 using namespace mlir;
27 
28 void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement,
29                                       Region &region) {
30   for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
31     if (region.isAncestor(use.getOwner()->getParentRegion()))
32       use.set(replacement);
33   }
34 }
35 
36 void mlir::visitUsedValuesDefinedAbove(
37     Region &region, Region &limit, function_ref<void(OpOperand *)> callback) {
38   assert(limit.isAncestor(&region) &&
39          "expected isolation limit to be an ancestor of the given region");
40 
41   // Collect proper ancestors of `limit` upfront to avoid traversing the region
42   // tree for every value.
43   SmallPtrSet<Region *, 4> properAncestors;
44   for (auto *reg = limit.getParentRegion(); reg != nullptr;
45        reg = reg->getParentRegion()) {
46     properAncestors.insert(reg);
47   }
48 
49   region.walk([callback, &properAncestors](Operation *op) {
50     for (OpOperand &operand : op->getOpOperands())
51       // Callback on values defined in a proper ancestor of region.
52       if (properAncestors.count(operand.get().getParentRegion()))
53         callback(&operand);
54   });
55 }
56 
57 void mlir::visitUsedValuesDefinedAbove(
58     MutableArrayRef<Region> regions, function_ref<void(OpOperand *)> callback) {
59   for (Region &region : regions)
60     visitUsedValuesDefinedAbove(region, region, callback);
61 }
62 
63 void mlir::getUsedValuesDefinedAbove(Region &region, Region &limit,
64                                      SetVector<Value> &values) {
65   visitUsedValuesDefinedAbove(region, limit, [&](OpOperand *operand) {
66     values.insert(operand->get());
67   });
68 }
69 
70 void mlir::getUsedValuesDefinedAbove(MutableArrayRef<Region> regions,
71                                      SetVector<Value> &values) {
72   for (Region &region : regions)
73     getUsedValuesDefinedAbove(region, region, values);
74 }
75 
76 //===----------------------------------------------------------------------===//
77 // Make block isolated from above.
78 //===----------------------------------------------------------------------===//
79 
80 SmallVector<Value> mlir::makeRegionIsolatedFromAbove(
81     RewriterBase &rewriter, Region &region,
82     llvm::function_ref<bool(Operation *)> cloneOperationIntoRegion) {
83 
84   // Get initial list of values used within region but defined above.
85   llvm::SetVector<Value> initialCapturedValues;
86   mlir::getUsedValuesDefinedAbove(region, initialCapturedValues);
87 
88   std::deque<Value> worklist(initialCapturedValues.begin(),
89                              initialCapturedValues.end());
90   llvm::DenseSet<Value> visited;
91   llvm::DenseSet<Operation *> visitedOps;
92 
93   llvm::SetVector<Value> finalCapturedValues;
94   SmallVector<Operation *> clonedOperations;
95   while (!worklist.empty()) {
96     Value currValue = worklist.front();
97     worklist.pop_front();
98     if (visited.count(currValue))
99       continue;
100     visited.insert(currValue);
101 
102     Operation *definingOp = currValue.getDefiningOp();
103     if (!definingOp || visitedOps.count(definingOp)) {
104       finalCapturedValues.insert(currValue);
105       continue;
106     }
107     visitedOps.insert(definingOp);
108 
109     if (!cloneOperationIntoRegion(definingOp)) {
110       // Defining operation isnt cloned, so add the current value to final
111       // captured values list.
112       finalCapturedValues.insert(currValue);
113       continue;
114     }
115 
116     // Add all operands of the operation to the worklist and mark the op as to
117     // be cloned.
118     for (Value operand : definingOp->getOperands()) {
119       if (visited.count(operand))
120         continue;
121       worklist.push_back(operand);
122     }
123     clonedOperations.push_back(definingOp);
124   }
125 
126   // The operations to be cloned need to be ordered in topological order
127   // so that they can be cloned into the region without violating use-def
128   // chains.
129   mlir::computeTopologicalSorting(clonedOperations);
130 
131   OpBuilder::InsertionGuard g(rewriter);
132   // Collect types of existing block
133   Block *entryBlock = &region.front();
134   SmallVector<Type> newArgTypes =
135       llvm::to_vector(entryBlock->getArgumentTypes());
136   SmallVector<Location> newArgLocs = llvm::to_vector(llvm::map_range(
137       entryBlock->getArguments(), [](BlockArgument b) { return b.getLoc(); }));
138 
139   // Append the types of the captured values.
140   for (auto value : finalCapturedValues) {
141     newArgTypes.push_back(value.getType());
142     newArgLocs.push_back(value.getLoc());
143   }
144 
145   // Create a new entry block.
146   Block *newEntryBlock =
147       rewriter.createBlock(&region, region.begin(), newArgTypes, newArgLocs);
148   auto newEntryBlockArgs = newEntryBlock->getArguments();
149 
150   // Create a mapping between the captured values and the new arguments added.
151   IRMapping map;
152   auto replaceIfFn = [&](OpOperand &use) {
153     return use.getOwner()->getBlock()->getParent() == &region;
154   };
155   for (auto [arg, capturedVal] :
156        llvm::zip(newEntryBlockArgs.take_back(finalCapturedValues.size()),
157                  finalCapturedValues)) {
158     map.map(capturedVal, arg);
159     rewriter.replaceUsesWithIf(capturedVal, arg, replaceIfFn);
160   }
161   rewriter.setInsertionPointToStart(newEntryBlock);
162   for (auto clonedOp : clonedOperations) {
163     Operation *newOp = rewriter.clone(*clonedOp, map);
164     rewriter.replaceOpWithIf(clonedOp, newOp->getResults(), replaceIfFn);
165   }
166   rewriter.mergeBlocks(
167       entryBlock, newEntryBlock,
168       newEntryBlock->getArguments().take_front(entryBlock->getNumArguments()));
169   return llvm::to_vector(finalCapturedValues);
170 }
171 
172 //===----------------------------------------------------------------------===//
173 // Unreachable Block Elimination
174 //===----------------------------------------------------------------------===//
175 
176 /// Erase the unreachable blocks within the provided regions. Returns success
177 /// if any blocks were erased, failure otherwise.
178 // TODO: We could likely merge this with the DCE algorithm below.
179 LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter,
180                                            MutableArrayRef<Region> regions) {
181   // Set of blocks found to be reachable within a given region.
182   llvm::df_iterator_default_set<Block *, 16> reachable;
183   // If any blocks were found to be dead.
184   bool erasedDeadBlocks = false;
185 
186   SmallVector<Region *, 1> worklist;
187   worklist.reserve(regions.size());
188   for (Region &region : regions)
189     worklist.push_back(&region);
190   while (!worklist.empty()) {
191     Region *region = worklist.pop_back_val();
192     if (region->empty())
193       continue;
194 
195     // If this is a single block region, just collect the nested regions.
196     if (std::next(region->begin()) == region->end()) {
197       for (Operation &op : region->front())
198         for (Region &region : op.getRegions())
199           worklist.push_back(&region);
200       continue;
201     }
202 
203     // Mark all reachable blocks.
204     reachable.clear();
205     for (Block *block : depth_first_ext(&region->front(), reachable))
206       (void)block /* Mark all reachable blocks */;
207 
208     // Collect all of the dead blocks and push the live regions onto the
209     // worklist.
210     for (Block &block : llvm::make_early_inc_range(*region)) {
211       if (!reachable.count(&block)) {
212         block.dropAllDefinedValueUses();
213         rewriter.eraseBlock(&block);
214         erasedDeadBlocks = true;
215         continue;
216       }
217 
218       // Walk any regions within this block.
219       for (Operation &op : block)
220         for (Region &region : op.getRegions())
221           worklist.push_back(&region);
222     }
223   }
224 
225   return success(erasedDeadBlocks);
226 }
227 
228 //===----------------------------------------------------------------------===//
229 // Dead Code Elimination
230 //===----------------------------------------------------------------------===//
231 
232 namespace {
233 /// Data structure used to track which values have already been proved live.
234 ///
235 /// Because Operation's can have multiple results, this data structure tracks
236 /// liveness for both Value's and Operation's to avoid having to look through
237 /// all Operation results when analyzing a use.
238 ///
239 /// This data structure essentially tracks the dataflow lattice.
240 /// The set of values/ops proved live increases monotonically to a fixed-point.
241 class LiveMap {
242 public:
243   /// Value methods.
244   bool wasProvenLive(Value value) {
245     // TODO: For results that are removable, e.g. for region based control flow,
246     // we could allow for these values to be tracked independently.
247     if (OpResult result = dyn_cast<OpResult>(value))
248       return wasProvenLive(result.getOwner());
249     return wasProvenLive(cast<BlockArgument>(value));
250   }
251   bool wasProvenLive(BlockArgument arg) { return liveValues.count(arg); }
252   void setProvedLive(Value value) {
253     // TODO: For results that are removable, e.g. for region based control flow,
254     // we could allow for these values to be tracked independently.
255     if (OpResult result = dyn_cast<OpResult>(value))
256       return setProvedLive(result.getOwner());
257     setProvedLive(cast<BlockArgument>(value));
258   }
259   void setProvedLive(BlockArgument arg) {
260     changed |= liveValues.insert(arg).second;
261   }
262 
263   /// Operation methods.
264   bool wasProvenLive(Operation *op) { return liveOps.count(op); }
265   void setProvedLive(Operation *op) { changed |= liveOps.insert(op).second; }
266 
267   /// Methods for tracking if we have reached a fixed-point.
268   void resetChanged() { changed = false; }
269   bool hasChanged() { return changed; }
270 
271 private:
272   bool changed = false;
273   DenseSet<Value> liveValues;
274   DenseSet<Operation *> liveOps;
275 };
276 } // namespace
277 
278 static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap) {
279   Operation *owner = use.getOwner();
280   unsigned operandIndex = use.getOperandNumber();
281   // This pass generally treats all uses of an op as live if the op itself is
282   // considered live. However, for successor operands to terminators we need a
283   // finer-grained notion where we deduce liveness for operands individually.
284   // The reason for this is easiest to think about in terms of a classical phi
285   // node based SSA IR, where each successor operand is really an operand to a
286   // *separate* phi node, rather than all operands to the branch itself as with
287   // the block argument representation that MLIR uses.
288   //
289   // And similarly, because each successor operand is really an operand to a phi
290   // node, rather than to the terminator op itself, a terminator op can't e.g.
291   // "print" the value of a successor operand.
292   if (owner->hasTrait<OpTrait::IsTerminator>()) {
293     if (BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(owner))
294       if (auto arg = branchInterface.getSuccessorBlockArgument(operandIndex))
295         return !liveMap.wasProvenLive(*arg);
296     return false;
297   }
298   return false;
299 }
300 
301 static void processValue(Value value, LiveMap &liveMap) {
302   bool provedLive = llvm::any_of(value.getUses(), [&](OpOperand &use) {
303     if (isUseSpeciallyKnownDead(use, liveMap))
304       return false;
305     return liveMap.wasProvenLive(use.getOwner());
306   });
307   if (provedLive)
308     liveMap.setProvedLive(value);
309 }
310 
311 static void propagateLiveness(Region &region, LiveMap &liveMap);
312 
313 static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
314   // Terminators are always live.
315   liveMap.setProvedLive(op);
316 
317   // Check to see if we can reason about the successor operands and mutate them.
318   BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(op);
319   if (!branchInterface) {
320     for (Block *successor : op->getSuccessors())
321       for (BlockArgument arg : successor->getArguments())
322         liveMap.setProvedLive(arg);
323     return;
324   }
325 
326   // If we can't reason about the operand to a successor, conservatively mark
327   // it as live.
328   for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
329     SuccessorOperands successorOperands =
330         branchInterface.getSuccessorOperands(i);
331     for (unsigned opI = 0, opE = successorOperands.getProducedOperandCount();
332          opI != opE; ++opI)
333       liveMap.setProvedLive(op->getSuccessor(i)->getArgument(opI));
334   }
335 }
336 
337 static void propagateLiveness(Operation *op, LiveMap &liveMap) {
338   // Recurse on any regions the op has.
339   for (Region &region : op->getRegions())
340     propagateLiveness(region, liveMap);
341 
342   // Process terminator operations.
343   if (op->hasTrait<OpTrait::IsTerminator>())
344     return propagateTerminatorLiveness(op, liveMap);
345 
346   // Don't reprocess live operations.
347   if (liveMap.wasProvenLive(op))
348     return;
349 
350   // Process the op itself.
351   if (!wouldOpBeTriviallyDead(op))
352     return liveMap.setProvedLive(op);
353 
354   // If the op isn't intrinsically alive, check it's results.
355   for (Value value : op->getResults())
356     processValue(value, liveMap);
357 }
358 
359 static void propagateLiveness(Region &region, LiveMap &liveMap) {
360   if (region.empty())
361     return;
362 
363   for (Block *block : llvm::post_order(&region.front())) {
364     // We process block arguments after the ops in the block, to promote
365     // faster convergence to a fixed point (we try to visit uses before defs).
366     for (Operation &op : llvm::reverse(block->getOperations()))
367       propagateLiveness(&op, liveMap);
368 
369     // We currently do not remove entry block arguments, so there is no need to
370     // track their liveness.
371     // TODO: We could track these and enable removing dead operands/arguments
372     // from region control flow operations.
373     if (block->isEntryBlock())
374       continue;
375 
376     for (Value value : block->getArguments()) {
377       if (!liveMap.wasProvenLive(value))
378         processValue(value, liveMap);
379     }
380   }
381 }
382 
383 static void eraseTerminatorSuccessorOperands(Operation *terminator,
384                                              LiveMap &liveMap) {
385   BranchOpInterface branchOp = dyn_cast<BranchOpInterface>(terminator);
386   if (!branchOp)
387     return;
388 
389   for (unsigned succI = 0, succE = terminator->getNumSuccessors();
390        succI < succE; succI++) {
391     // Iterating successors in reverse is not strictly needed, since we
392     // aren't erasing any successors. But it is slightly more efficient
393     // since it will promote later operands of the terminator being erased
394     // first, reducing the quadratic-ness.
395     unsigned succ = succE - succI - 1;
396     SuccessorOperands succOperands = branchOp.getSuccessorOperands(succ);
397     Block *successor = terminator->getSuccessor(succ);
398 
399     for (unsigned argI = 0, argE = succOperands.size(); argI < argE; ++argI) {
400       // Iterating args in reverse is needed for correctness, to avoid
401       // shifting later args when earlier args are erased.
402       unsigned arg = argE - argI - 1;
403       if (!liveMap.wasProvenLive(successor->getArgument(arg)))
404         succOperands.erase(arg);
405     }
406   }
407 }
408 
409 static LogicalResult deleteDeadness(RewriterBase &rewriter,
410                                     MutableArrayRef<Region> regions,
411                                     LiveMap &liveMap) {
412   bool erasedAnything = false;
413   for (Region &region : regions) {
414     if (region.empty())
415       continue;
416     bool hasSingleBlock = llvm::hasSingleElement(region);
417 
418     // Delete every operation that is not live. Graph regions may have cycles
419     // in the use-def graph, so we must explicitly dropAllUses() from each
420     // operation as we erase it. Visiting the operations in post-order
421     // guarantees that in SSA CFG regions value uses are removed before defs,
422     // which makes dropAllUses() a no-op.
423     for (Block *block : llvm::post_order(&region.front())) {
424       if (!hasSingleBlock)
425         eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap);
426       for (Operation &childOp :
427            llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) {
428         if (!liveMap.wasProvenLive(&childOp)) {
429           erasedAnything = true;
430           childOp.dropAllUses();
431           rewriter.eraseOp(&childOp);
432         } else {
433           erasedAnything |= succeeded(
434               deleteDeadness(rewriter, childOp.getRegions(), liveMap));
435         }
436       }
437     }
438     // Delete block arguments.
439     // The entry block has an unknown contract with their enclosing block, so
440     // skip it.
441     for (Block &block : llvm::drop_begin(region.getBlocks(), 1)) {
442       block.eraseArguments(
443           [&](BlockArgument arg) { return !liveMap.wasProvenLive(arg); });
444     }
445   }
446   return success(erasedAnything);
447 }
448 
449 // This function performs a simple dead code elimination algorithm over the
450 // given regions.
451 //
452 // The overall goal is to prove that Values are dead, which allows deleting ops
453 // and block arguments.
454 //
455 // This uses an optimistic algorithm that assumes everything is dead until
456 // proved otherwise, allowing it to delete recursively dead cycles.
457 //
458 // This is a simple fixed-point dataflow analysis algorithm on a lattice
459 // {Dead,Alive}. Because liveness flows backward, we generally try to
460 // iterate everything backward to speed up convergence to the fixed-point. This
461 // allows for being able to delete recursively dead cycles of the use-def graph,
462 // including block arguments.
463 //
464 // This function returns success if any operations or arguments were deleted,
465 // failure otherwise.
466 LogicalResult mlir::runRegionDCE(RewriterBase &rewriter,
467                                  MutableArrayRef<Region> regions) {
468   LiveMap liveMap;
469   do {
470     liveMap.resetChanged();
471 
472     for (Region &region : regions)
473       propagateLiveness(region, liveMap);
474   } while (liveMap.hasChanged());
475 
476   return deleteDeadness(rewriter, regions, liveMap);
477 }
478 
479 //===----------------------------------------------------------------------===//
480 // Block Merging
481 //===----------------------------------------------------------------------===//
482 
483 //===----------------------------------------------------------------------===//
484 // BlockEquivalenceData
485 
486 namespace {
487 /// This class contains the information for comparing the equivalencies of two
488 /// blocks. Blocks are considered equivalent if they contain the same operations
489 /// in the same order. The only allowed divergence is for operands that come
490 /// from sources outside of the parent block, i.e. the uses of values produced
491 /// within the block must be equivalent.
492 ///   e.g.,
493 /// Equivalent:
494 ///  ^bb1(%arg0: i32)
495 ///    return %arg0, %foo : i32, i32
496 ///  ^bb2(%arg1: i32)
497 ///    return %arg1, %bar : i32, i32
498 /// Not Equivalent:
499 ///  ^bb1(%arg0: i32)
500 ///    return %foo, %arg0 : i32, i32
501 ///  ^bb2(%arg1: i32)
502 ///    return %arg1, %bar : i32, i32
503 struct BlockEquivalenceData {
504   BlockEquivalenceData(Block *block);
505 
506   /// Return the order index for the given value that is within the block of
507   /// this data.
508   unsigned getOrderOf(Value value) const;
509 
510   /// The block this data refers to.
511   Block *block;
512   /// A hash value for this block.
513   llvm::hash_code hash;
514   /// A map of result producing operations to their relative orders within this
515   /// block. The order of an operation is the number of defined values that are
516   /// produced within the block before this operation.
517   DenseMap<Operation *, unsigned> opOrderIndex;
518 };
519 } // namespace
520 
521 BlockEquivalenceData::BlockEquivalenceData(Block *block)
522     : block(block), hash(0) {
523   unsigned orderIt = block->getNumArguments();
524   for (Operation &op : *block) {
525     if (unsigned numResults = op.getNumResults()) {
526       opOrderIndex.try_emplace(&op, orderIt);
527       orderIt += numResults;
528     }
529     auto opHash = OperationEquivalence::computeHash(
530         &op, OperationEquivalence::ignoreHashValue,
531         OperationEquivalence::ignoreHashValue,
532         OperationEquivalence::IgnoreLocations);
533     hash = llvm::hash_combine(hash, opHash);
534   }
535 }
536 
537 unsigned BlockEquivalenceData::getOrderOf(Value value) const {
538   assert(value.getParentBlock() == block && "expected value of this block");
539 
540   // Arguments use the argument number as the order index.
541   if (BlockArgument arg = dyn_cast<BlockArgument>(value))
542     return arg.getArgNumber();
543 
544   // Otherwise, the result order is offset from the parent op's order.
545   OpResult result = cast<OpResult>(value);
546   auto opOrderIt = opOrderIndex.find(result.getDefiningOp());
547   assert(opOrderIt != opOrderIndex.end() && "expected op to have an order");
548   return opOrderIt->second + result.getResultNumber();
549 }
550 
551 //===----------------------------------------------------------------------===//
552 // BlockMergeCluster
553 
554 namespace {
555 /// This class represents a cluster of blocks to be merged together.
556 class BlockMergeCluster {
557 public:
558   BlockMergeCluster(BlockEquivalenceData &&leaderData)
559       : leaderData(std::move(leaderData)) {}
560 
561   /// Attempt to add the given block to this cluster. Returns success if the
562   /// block was merged, failure otherwise.
563   LogicalResult addToCluster(BlockEquivalenceData &blockData);
564 
565   /// Try to merge all of the blocks within this cluster into the leader block.
566   LogicalResult merge(RewriterBase &rewriter);
567 
568 private:
569   /// The equivalence data for the leader of the cluster.
570   BlockEquivalenceData leaderData;
571 
572   /// The set of blocks that can be merged into the leader.
573   llvm::SmallSetVector<Block *, 1> blocksToMerge;
574 
575   /// A set of operand+index pairs that correspond to operands that need to be
576   /// replaced by arguments when the cluster gets merged.
577   std::set<std::pair<int, int>> operandsToMerge;
578 };
579 } // namespace
580 
581 LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
582   if (leaderData.hash != blockData.hash)
583     return failure();
584   Block *leaderBlock = leaderData.block, *mergeBlock = blockData.block;
585   if (leaderBlock->getArgumentTypes() != mergeBlock->getArgumentTypes())
586     return failure();
587 
588   // A set of operands that mismatch between the leader and the new block.
589   SmallVector<std::pair<int, int>, 8> mismatchedOperands;
590   auto lhsIt = leaderBlock->begin(), lhsE = leaderBlock->end();
591   auto rhsIt = blockData.block->begin(), rhsE = blockData.block->end();
592   for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) {
593     // Check that the operations are equivalent.
594     if (!OperationEquivalence::isEquivalentTo(
595             &*lhsIt, &*rhsIt, OperationEquivalence::ignoreValueEquivalence,
596             /*markEquivalent=*/nullptr,
597             OperationEquivalence::Flags::IgnoreLocations))
598       return failure();
599 
600     // Compare the operands of the two operations. If the operand is within
601     // the block, it must refer to the same operation.
602     auto lhsOperands = lhsIt->getOperands(), rhsOperands = rhsIt->getOperands();
603     for (int operand : llvm::seq<int>(0, lhsIt->getNumOperands())) {
604       Value lhsOperand = lhsOperands[operand];
605       Value rhsOperand = rhsOperands[operand];
606       if (lhsOperand == rhsOperand)
607         continue;
608       // Check that the types of the operands match.
609       if (lhsOperand.getType() != rhsOperand.getType())
610         return failure();
611 
612       // Check that these uses are both external, or both internal.
613       bool lhsIsInBlock = lhsOperand.getParentBlock() == leaderBlock;
614       bool rhsIsInBlock = rhsOperand.getParentBlock() == mergeBlock;
615       if (lhsIsInBlock != rhsIsInBlock)
616         return failure();
617       // Let the operands differ if they are defined in a different block. These
618       // will become new arguments if the blocks get merged.
619       if (!lhsIsInBlock) {
620 
621         // Check whether the operands aren't the result of an immediate
622         // predecessors terminator. In that case we are not able to use it as a
623         // successor operand when branching to the merged block as it does not
624         // dominate its producing operation.
625         auto isValidSuccessorArg = [](Block *block, Value operand) {
626           if (operand.getDefiningOp() !=
627               operand.getParentBlock()->getTerminator())
628             return true;
629           return !llvm::is_contained(block->getPredecessors(),
630                                      operand.getParentBlock());
631         };
632 
633         if (!isValidSuccessorArg(leaderBlock, lhsOperand) ||
634             !isValidSuccessorArg(mergeBlock, rhsOperand))
635           return failure();
636 
637         mismatchedOperands.emplace_back(opI, operand);
638         continue;
639       }
640 
641       // Otherwise, these operands must have the same logical order within the
642       // parent block.
643       if (leaderData.getOrderOf(lhsOperand) != blockData.getOrderOf(rhsOperand))
644         return failure();
645     }
646 
647     // If the lhs or rhs has external uses, the blocks cannot be merged as the
648     // merged version of this operation will not be either the lhs or rhs
649     // alone (thus semantically incorrect), but some mix dependending on which
650     // block preceeded this.
651     // TODO allow merging of operations when one block does not dominate the
652     // other
653     if (rhsIt->isUsedOutsideOfBlock(mergeBlock) ||
654         lhsIt->isUsedOutsideOfBlock(leaderBlock)) {
655       return failure();
656     }
657   }
658   // Make sure that the block sizes are equivalent.
659   if (lhsIt != lhsE || rhsIt != rhsE)
660     return failure();
661 
662   // If we get here, the blocks are equivalent and can be merged.
663   operandsToMerge.insert(mismatchedOperands.begin(), mismatchedOperands.end());
664   blocksToMerge.insert(blockData.block);
665   return success();
666 }
667 
668 /// Returns true if the predecessor terminators of the given block can not have
669 /// their operands updated.
670 static bool ableToUpdatePredOperands(Block *block) {
671   for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
672     if (!isa<BranchOpInterface>((*it)->getTerminator()))
673       return false;
674   }
675   return true;
676 }
677 
678 LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
679   // Don't consider clusters that don't have blocks to merge.
680   if (blocksToMerge.empty())
681     return failure();
682 
683   Block *leaderBlock = leaderData.block;
684   if (!operandsToMerge.empty()) {
685     // If the cluster has operands to merge, verify that the predecessor
686     // terminators of each of the blocks can have their successor operands
687     // updated.
688     // TODO: We could try and sub-partition this cluster if only some blocks
689     // cause the mismatch.
690     if (!ableToUpdatePredOperands(leaderBlock) ||
691         !llvm::all_of(blocksToMerge, ableToUpdatePredOperands))
692       return failure();
693 
694     // Collect the iterators for each of the blocks to merge. We will walk all
695     // of the iterators at once to avoid operand index invalidation.
696     SmallVector<Block::iterator, 2> blockIterators;
697     blockIterators.reserve(blocksToMerge.size() + 1);
698     blockIterators.push_back(leaderBlock->begin());
699     for (Block *mergeBlock : blocksToMerge)
700       blockIterators.push_back(mergeBlock->begin());
701 
702     // Update each of the predecessor terminators with the new arguments.
703     SmallVector<SmallVector<Value, 8>, 2> newArguments(
704         1 + blocksToMerge.size(),
705         SmallVector<Value, 8>(operandsToMerge.size()));
706     unsigned curOpIndex = 0;
707     for (const auto &it : llvm::enumerate(operandsToMerge)) {
708       unsigned nextOpOffset = it.value().first - curOpIndex;
709       curOpIndex = it.value().first;
710 
711       // Process the operand for each of the block iterators.
712       for (unsigned i = 0, e = blockIterators.size(); i != e; ++i) {
713         Block::iterator &blockIter = blockIterators[i];
714         std::advance(blockIter, nextOpOffset);
715         auto &operand = blockIter->getOpOperand(it.value().second);
716         newArguments[i][it.index()] = operand.get();
717 
718         // Update the operand and insert an argument if this is the leader.
719         if (i == 0) {
720           Value operandVal = operand.get();
721           operand.set(leaderBlock->addArgument(operandVal.getType(),
722                                                operandVal.getLoc()));
723         }
724       }
725     }
726     // Update the predecessors for each of the blocks.
727     auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
728       for (auto predIt = block->pred_begin(), predE = block->pred_end();
729            predIt != predE; ++predIt) {
730         auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
731         unsigned succIndex = predIt.getSuccessorIndex();
732         branch.getSuccessorOperands(succIndex).append(
733             newArguments[clusterIndex]);
734       }
735     };
736     updatePredecessors(leaderBlock, /*clusterIndex=*/0);
737     for (unsigned i = 0, e = blocksToMerge.size(); i != e; ++i)
738       updatePredecessors(blocksToMerge[i], /*clusterIndex=*/i + 1);
739   }
740 
741   // Replace all uses of the merged blocks with the leader and erase them.
742   for (Block *block : blocksToMerge) {
743     block->replaceAllUsesWith(leaderBlock);
744     rewriter.eraseBlock(block);
745   }
746   return success();
747 }
748 
749 /// Identify identical blocks within the given region and merge them, inserting
750 /// new block arguments as necessary. Returns success if any blocks were merged,
751 /// failure otherwise.
752 static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
753                                           Region &region) {
754   if (region.empty() || llvm::hasSingleElement(region))
755     return failure();
756 
757   // Identify sets of blocks, other than the entry block, that branch to the
758   // same successors. We will use these groups to create clusters of equivalent
759   // blocks.
760   DenseMap<SuccessorRange, SmallVector<Block *, 1>> matchingSuccessors;
761   for (Block &block : llvm::drop_begin(region, 1))
762     matchingSuccessors[block.getSuccessors()].push_back(&block);
763 
764   bool mergedAnyBlocks = false;
765   for (ArrayRef<Block *> blocks : llvm::make_second_range(matchingSuccessors)) {
766     if (blocks.size() == 1)
767       continue;
768 
769     SmallVector<BlockMergeCluster, 1> clusters;
770     for (Block *block : blocks) {
771       BlockEquivalenceData data(block);
772 
773       // Don't allow merging if this block has any regions.
774       // TODO: Add support for regions if necessary.
775       bool hasNonEmptyRegion = llvm::any_of(*block, [](Operation &op) {
776         return llvm::any_of(op.getRegions(),
777                             [](Region &region) { return !region.empty(); });
778       });
779       if (hasNonEmptyRegion)
780         continue;
781 
782       // Try to add this block to an existing cluster.
783       bool addedToCluster = false;
784       for (auto &cluster : clusters)
785         if ((addedToCluster = succeeded(cluster.addToCluster(data))))
786           break;
787       if (!addedToCluster)
788         clusters.emplace_back(std::move(data));
789     }
790     for (auto &cluster : clusters)
791       mergedAnyBlocks |= succeeded(cluster.merge(rewriter));
792   }
793 
794   return success(mergedAnyBlocks);
795 }
796 
797 /// Identify identical blocks within the given regions and merge them, inserting
798 /// new block arguments as necessary.
799 static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
800                                           MutableArrayRef<Region> regions) {
801   llvm::SmallSetVector<Region *, 1> worklist;
802   for (auto &region : regions)
803     worklist.insert(&region);
804   bool anyChanged = false;
805   while (!worklist.empty()) {
806     Region *region = worklist.pop_back_val();
807     if (succeeded(mergeIdenticalBlocks(rewriter, *region))) {
808       worklist.insert(region);
809       anyChanged = true;
810     }
811 
812     // Add any nested regions to the worklist.
813     for (Block &block : *region)
814       for (auto &op : block)
815         for (auto &nestedRegion : op.getRegions())
816           worklist.insert(&nestedRegion);
817   }
818 
819   return success(anyChanged);
820 }
821 
822 //===----------------------------------------------------------------------===//
823 // Region Simplification
824 //===----------------------------------------------------------------------===//
825 
826 /// Run a set of structural simplifications over the given regions. This
827 /// includes transformations like unreachable block elimination, dead argument
828 /// elimination, as well as some other DCE. This function returns success if any
829 /// of the regions were simplified, failure otherwise.
830 LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
831                                     MutableArrayRef<Region> regions) {
832   bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
833   bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
834   bool mergedIdenticalBlocks =
835       succeeded(mergeIdenticalBlocks(rewriter, regions));
836   return success(eliminatedBlocks || eliminatedOpsOrArgs ||
837                  mergedIdenticalBlocks);
838 }
839 
840 SetVector<Block *> mlir::getTopologicallySortedBlocks(Region &region) {
841   // For each block that has not been visited yet (i.e. that has no
842   // predecessors), add it to the list as well as its successors.
843   SetVector<Block *> blocks;
844   for (Block &b : region) {
845     if (blocks.count(&b) == 0) {
846       llvm::ReversePostOrderTraversal<Block *> traversal(&b);
847       blocks.insert(traversal.begin(), traversal.end());
848     }
849   }
850   assert(blocks.size() == region.getBlocks().size() &&
851          "some blocks are not sorted");
852 
853   return blocks;
854 }
855