//===- RegionUtils.cpp - Region-related transformation utilities ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/RegionUtils.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/RegionGraphTraits.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LogicalResult.h"

#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"

#include <deque>
#include <iterator>

using namespace mlir;

void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement,
                                      Region &region) {
  for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
    if (region.isAncestor(use.getOwner()->getParentRegion()))
      use.set(replacement);
  }
}

void mlir::visitUsedValuesDefinedAbove(
    Region &region, Region &limit, function_ref<void(OpOperand *)> callback) {
  assert(limit.isAncestor(&region) &&
         "expected isolation limit to be an ancestor of the given region");

  // Collect proper ancestors of `limit` upfront to avoid traversing the region
  // tree for every value.
  SmallPtrSet<Region *, 4> properAncestors;
  for (auto *reg = limit.getParentRegion(); reg != nullptr;
       reg = reg->getParentRegion()) {
    properAncestors.insert(reg);
  }

  region.walk([callback, &properAncestors](Operation *op) {
    for (OpOperand &operand : op->getOpOperands())
      // Callback on values defined in a proper ancestor of region.
      if (properAncestors.count(operand.get().getParentRegion()))
        callback(&operand);
  });
}

void mlir::visitUsedValuesDefinedAbove(
    MutableArrayRef<Region> regions, function_ref<void(OpOperand *)> callback) {
  for (Region &region : regions)
    visitUsedValuesDefinedAbove(region, region, callback);
}

void mlir::getUsedValuesDefinedAbove(Region &region, Region &limit,
                                     SetVector<Value> &values) {
  visitUsedValuesDefinedAbove(region, limit, [&](OpOperand *operand) {
    values.insert(operand->get());
  });
}

void mlir::getUsedValuesDefinedAbove(MutableArrayRef<Region> regions,
                                     SetVector<Value> &values) {
  for (Region &region : regions)
    getUsedValuesDefinedAbove(region, region, values);
}

//===----------------------------------------------------------------------===//
// Make region isolated from above.
//===----------------------------------------------------------------------===//

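// For illustration, using hypothetical "test.*" ops: given a region of
// "test.op" that uses a value defined above it,
//
//   %0 = "test.def"() : () -> i32
//   "test.op"() ({
//   ^bb0:
//     "test.use"(%0) : (i32) -> ()
//   }) : () -> ()
//
// the utility below rewrites the region so that %0 is captured via a new entry
// block argument and returns [%0] as the list of captured values; forwarding
// those values into the region (e.g. as operands of "test.op") is left to the
// caller:
//
//   %0 = "test.def"() : () -> i32
//   "test.op"() ({
//   ^bb0(%arg0: i32):
//     "test.use"(%arg0) : (i32) -> ()
//   }) : () -> ()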
SmallVector<Value> mlir::makeRegionIsolatedFromAbove(
    RewriterBase &rewriter, Region &region,
    llvm::function_ref<bool(Operation *)> cloneOperationIntoRegion) {

  // Get initial list of values used within region but defined above.
  llvm::SetVector<Value> initialCapturedValues;
  mlir::getUsedValuesDefinedAbove(region, initialCapturedValues);

  std::deque<Value> worklist(initialCapturedValues.begin(),
                             initialCapturedValues.end());
  llvm::DenseSet<Value> visited;
  llvm::DenseSet<Operation *> visitedOps;

  llvm::SetVector<Value> finalCapturedValues;
  SmallVector<Operation *> clonedOperations;
  while (!worklist.empty()) {
    Value currValue = worklist.front();
    worklist.pop_front();
    if (visited.count(currValue))
      continue;
    visited.insert(currValue);

    Operation *definingOp = currValue.getDefiningOp();
    if (!definingOp || visitedOps.count(definingOp)) {
      finalCapturedValues.insert(currValue);
      continue;
    }
    visitedOps.insert(definingOp);

    if (!cloneOperationIntoRegion(definingOp)) {
      // The defining operation isn't cloned, so add the current value to the
      // final captured values list.
      finalCapturedValues.insert(currValue);
      continue;
    }

    // Add all operands of the operation to the worklist and mark the op to be
    // cloned.
    for (Value operand : definingOp->getOperands()) {
      if (visited.count(operand))
        continue;
      worklist.push_back(operand);
    }
    clonedOperations.push_back(definingOp);
  }

  // The operations to be cloned need to be sorted topologically so that they
  // can be cloned into the region without violating use-def chains.
  mlir::computeTopologicalSorting(clonedOperations);

  OpBuilder::InsertionGuard g(rewriter);
  // Collect the argument types and locations of the existing entry block.
  Block *entryBlock = &region.front();
  SmallVector<Type> newArgTypes =
      llvm::to_vector(entryBlock->getArgumentTypes());
  SmallVector<Location> newArgLocs = llvm::to_vector(llvm::map_range(
      entryBlock->getArguments(), [](BlockArgument b) { return b.getLoc(); }));

  // Append the types and locations of the captured values.
  for (auto value : finalCapturedValues) {
    newArgTypes.push_back(value.getType());
    newArgLocs.push_back(value.getLoc());
  }

  // Create a new entry block.
  Block *newEntryBlock =
      rewriter.createBlock(&region, region.begin(), newArgTypes, newArgLocs);
  auto newEntryBlockArgs = newEntryBlock->getArguments();

  // Create a mapping between the captured values and the new arguments added.
  IRMapping map;
  auto replaceIfFn = [&](OpOperand &use) {
    return use.getOwner()->getBlock()->getParent() == &region;
  };
  for (auto [arg, capturedVal] :
       llvm::zip(newEntryBlockArgs.take_back(finalCapturedValues.size()),
                 finalCapturedValues)) {
    map.map(capturedVal, arg);
    rewriter.replaceUsesWithIf(capturedVal, arg, replaceIfFn);
  }
  rewriter.setInsertionPointToStart(newEntryBlock);
  for (auto *clonedOp : clonedOperations) {
    Operation *newOp = rewriter.clone(*clonedOp, map);
    rewriter.replaceOpUsesWithIf(clonedOp, newOp->getResults(), replaceIfFn);
  }
  rewriter.mergeBlocks(
      entryBlock, newEntryBlock,
      newEntryBlock->getArguments().take_front(entryBlock->getNumArguments()));
  return llvm::to_vector(finalCapturedValues);
}

//===----------------------------------------------------------------------===//
// Unreachable Block Elimination
//===----------------------------------------------------------------------===//

/// Erase the unreachable blocks within the provided regions. Returns success
/// if any blocks were erased, failure otherwise.
// TODO: We could likely merge this with the DCE algorithm below.
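// For illustration (hypothetical IR): in
//
//   ^bb0:
//     cf.br ^bb2
//   ^bb1:  // Never branched to, and not the entry block.
//     cf.br ^bb2
//   ^bb2:
//     "test.return"() : () -> ()
//
// the block ^bb1 is not reachable from the entry block ^bb0 and is erased,
// while ^bb0 and ^bb2 are kept.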
LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter,
                                           MutableArrayRef<Region> regions) {
  // Set of blocks found to be reachable within a given region.
  llvm::df_iterator_default_set<Block *, 16> reachable;
  // Whether any dead blocks were erased.
  bool erasedDeadBlocks = false;

  SmallVector<Region *, 1> worklist;
  worklist.reserve(regions.size());
  for (Region &region : regions)
    worklist.push_back(&region);
  while (!worklist.empty()) {
    Region *region = worklist.pop_back_val();
    if (region->empty())
      continue;

    // If this is a single block region, just collect the nested regions.
    if (std::next(region->begin()) == region->end()) {
      for (Operation &op : region->front())
        for (Region &region : op.getRegions())
          worklist.push_back(&region);
      continue;
    }

    // Mark all reachable blocks.
    reachable.clear();
    for (Block *block : depth_first_ext(&region->front(), reachable))
      (void)block /* Mark all reachable blocks */;

    // Collect all of the dead blocks and push the live regions onto the
    // worklist.
    for (Block &block : llvm::make_early_inc_range(*region)) {
      if (!reachable.count(&block)) {
        block.dropAllDefinedValueUses();
        rewriter.eraseBlock(&block);
        erasedDeadBlocks = true;
        continue;
      }

      // Walk any regions within this block.
      for (Operation &op : block)
        for (Region &region : op.getRegions())
          worklist.push_back(&region);
    }
  }

  return success(erasedDeadBlocks);
}

//===----------------------------------------------------------------------===//
// Dead Code Elimination
//===----------------------------------------------------------------------===//

namespace {
/// Data structure used to track which values have already been proved live.
///
/// Because an Operation can have multiple results, this data structure tracks
/// liveness for both Values and Operations to avoid having to look through
/// all Operation results when analyzing a use.
///
/// This data structure essentially tracks the dataflow lattice.
/// The set of values/ops proved live increases monotonically to a fixed-point.
class LiveMap {
public:
  /// Value methods.
  bool wasProvenLive(Value value) {
    // TODO: For results that are removable, e.g. for region based control flow,
    // we could allow for these values to be tracked independently.
    if (OpResult result = dyn_cast<OpResult>(value))
      return wasProvenLive(result.getOwner());
    return wasProvenLive(cast<BlockArgument>(value));
  }
  bool wasProvenLive(BlockArgument arg) { return liveValues.count(arg); }
  void setProvedLive(Value value) {
    // TODO: For results that are removable, e.g. for region based control flow,
    // we could allow for these values to be tracked independently.
    if (OpResult result = dyn_cast<OpResult>(value))
      return setProvedLive(result.getOwner());
    setProvedLive(cast<BlockArgument>(value));
  }
  void setProvedLive(BlockArgument arg) {
    changed |= liveValues.insert(arg).second;
  }

  /// Operation methods.
  bool wasProvenLive(Operation *op) { return liveOps.count(op); }
  void setProvedLive(Operation *op) { changed |= liveOps.insert(op).second; }

  /// Methods for tracking if we have reached a fixed-point.
  void resetChanged() { changed = false; }
  bool hasChanged() { return changed; }

private:
  bool changed = false;
  DenseSet<Value> liveValues;
  DenseSet<Operation *> liveOps;
};
} // namespace

static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap) {
  Operation *owner = use.getOwner();
  unsigned operandIndex = use.getOperandNumber();
  // This pass generally treats all uses of an op as live if the op itself is
  // considered live. However, for successor operands to terminators we need a
  // finer-grained notion where we deduce liveness for operands individually.
  // The reason for this is easiest to think about in terms of a classical phi
  // node based SSA IR, where each successor operand is really an operand to a
  // *separate* phi node, rather than all operands to the branch itself as with
  // the block argument representation that MLIR uses.
  //
  // And similarly, because each successor operand is really an operand to a phi
  // node, rather than to the terminator op itself, a terminator op can't e.g.
  // "print" the value of a successor operand.
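  //
  // For illustration (hypothetical IR): in "cf.cond_br %c, ^bb1(%v : i32),
  // ^bb2", the forwarded operand %v conceptually feeds ^bb1's block argument
  // rather than the cond_br itself, so if that block argument has not been
  // proven live, this use does not keep %v alive.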
  if (owner->hasTrait<OpTrait::IsTerminator>()) {
    if (BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(owner))
      if (auto arg = branchInterface.getSuccessorBlockArgument(operandIndex))
        return !liveMap.wasProvenLive(*arg);
    return false;
  }
  return false;
}

static void processValue(Value value, LiveMap &liveMap) {
  bool provedLive = llvm::any_of(value.getUses(), [&](OpOperand &use) {
    if (isUseSpeciallyKnownDead(use, liveMap))
      return false;
    return liveMap.wasProvenLive(use.getOwner());
  });
  if (provedLive)
    liveMap.setProvedLive(value);
}

static void propagateLiveness(Region &region, LiveMap &liveMap);

static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
  // Terminators are always live.
  liveMap.setProvedLive(op);

  // Check to see if we can reason about the successor operands and mutate them.
  BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(op);
  if (!branchInterface) {
    for (Block *successor : op->getSuccessors())
      for (BlockArgument arg : successor->getArguments())
        liveMap.setProvedLive(arg);
    return;
  }

  // If we can't reason about the operand to a successor, conservatively mark
  // it as live.
  for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
    SuccessorOperands successorOperands =
        branchInterface.getSuccessorOperands(i);
    for (unsigned opI = 0, opE = successorOperands.getProducedOperandCount();
         opI != opE; ++opI)
      liveMap.setProvedLive(op->getSuccessor(i)->getArgument(opI));
  }
}

static void propagateLiveness(Operation *op, LiveMap &liveMap) {
  // Recurse on any regions the op has.
  for (Region &region : op->getRegions())
    propagateLiveness(region, liveMap);

  // Process terminator operations.
  if (op->hasTrait<OpTrait::IsTerminator>())
    return propagateTerminatorLiveness(op, liveMap);

  // Don't reprocess live operations.
  if (liveMap.wasProvenLive(op))
    return;

  // Process the op itself.
  if (!wouldOpBeTriviallyDead(op))
    return liveMap.setProvedLive(op);

  // If the op isn't intrinsically alive, check its results.
  for (Value value : op->getResults())
    processValue(value, liveMap);
}

static void propagateLiveness(Region &region, LiveMap &liveMap) {
  if (region.empty())
    return;

  for (Block *block : llvm::post_order(&region.front())) {
    // We process block arguments after the ops in the block, to promote
    // faster convergence to a fixed point (we try to visit uses before defs).
    for (Operation &op : llvm::reverse(block->getOperations()))
      propagateLiveness(&op, liveMap);

    // We currently do not remove entry block arguments, so there is no need to
    // track their liveness.
    // TODO: We could track these and enable removing dead operands/arguments
    // from region control flow operations.
    if (block->isEntryBlock())
      continue;

    for (Value value : block->getArguments()) {
      if (!liveMap.wasProvenLive(value))
        processValue(value, liveMap);
    }
  }
}

static void eraseTerminatorSuccessorOperands(Operation *terminator,
                                             LiveMap &liveMap) {
  BranchOpInterface branchOp = dyn_cast<BranchOpInterface>(terminator);
  if (!branchOp)
    return;

  for (unsigned succI = 0, succE = terminator->getNumSuccessors();
       succI < succE; succI++) {
    // Iterating successors in reverse is not strictly needed, since we
    // aren't erasing any successors. But it is slightly more efficient,
    // since it causes later operands of the terminator to be erased first,
    // reducing the quadratic cost of repeated operand erasure.
    unsigned succ = succE - succI - 1;
    SuccessorOperands succOperands = branchOp.getSuccessorOperands(succ);
    Block *successor = terminator->getSuccessor(succ);

    for (unsigned argI = 0, argE = succOperands.size(); argI < argE; ++argI) {
      // Iterating args in reverse is needed for correctness, to avoid
      // shifting later args when earlier args are erased.
      unsigned arg = argE - argI - 1;
      if (!liveMap.wasProvenLive(successor->getArgument(arg)))
        succOperands.erase(arg);
    }
  }
}

static LogicalResult deleteDeadness(RewriterBase &rewriter,
                                    MutableArrayRef<Region> regions,
                                    LiveMap &liveMap) {
  bool erasedAnything = false;
  for (Region &region : regions) {
    if (region.empty())
      continue;
    bool hasSingleBlock = llvm::hasSingleElement(region);

    // Delete every operation that is not live. Graph regions may have cycles
    // in the use-def graph, so we must explicitly dropAllUses() from each
    // operation as we erase it. Visiting the operations in post-order
    // guarantees that in SSA CFG regions value uses are removed before defs,
    // which makes dropAllUses() a no-op.
    for (Block *block : llvm::post_order(&region.front())) {
      if (!hasSingleBlock)
        eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap);
      for (Operation &childOp :
           llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) {
        if (!liveMap.wasProvenLive(&childOp)) {
          erasedAnything = true;
          childOp.dropAllUses();
          rewriter.eraseOp(&childOp);
        } else {
          erasedAnything |= succeeded(
              deleteDeadness(rewriter, childOp.getRegions(), liveMap));
        }
      }
    }
    // Delete block arguments.
    // The entry block has an unknown contract with its enclosing operation, so
    // skip it.
    for (Block &block : llvm::drop_begin(region.getBlocks(), 1)) {
      block.eraseArguments(
          [&](BlockArgument arg) { return !liveMap.wasProvenLive(arg); });
    }
  }
  return success(erasedAnything);
}

// This function performs a simple dead code elimination algorithm over the
// given regions.
//
// The overall goal is to prove that Values are dead, which allows deleting ops
// and block arguments.
//
// This uses an optimistic algorithm that assumes everything is dead until
// proved otherwise, allowing it to delete recursively dead cycles.
//
// This is a simple fixed-point dataflow analysis algorithm on a lattice
// {Dead,Alive}. Because liveness flows backward, we generally try to
// iterate everything backward to speed up convergence to the fixed-point. This
// makes it possible to delete recursively dead cycles of the use-def graph,
// including block arguments.
//
// This function returns success if any operations or arguments were deleted,
// failure otherwise.
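//
// As a small illustration (hypothetical IR): in
//
//   ^bb0:
//     %cond = "test.cond"() : () -> i1
//     %init = "test.def"() : () -> i32
//     cf.br ^bb1(%init : i32)
//   ^bb1(%arg : i32):
//     %next = "test.step"(%arg) : (i32) -> i32
//     cf.cond_br %cond, ^bb1(%next : i32), ^bb2
//   ^bb2:
//     "test.return"() : () -> ()
//
// assuming "test.step" has no side effects, %arg and %next only feed each
// other, so neither is ever proven live; the algorithm deletes the "test.step"
// op, the block argument of ^bb1, and the corresponding successor operands,
// even though they form a use-def cycle.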
LogicalResult mlir::runRegionDCE(RewriterBase &rewriter,
                                 MutableArrayRef<Region> regions) {
  LiveMap liveMap;
  do {
    liveMap.resetChanged();

    for (Region &region : regions)
      propagateLiveness(region, liveMap);
  } while (liveMap.hasChanged());

  return deleteDeadness(rewriter, regions, liveMap);
}

//===----------------------------------------------------------------------===//
// Block Merging
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// BlockEquivalenceData

namespace {
/// This class contains the information for comparing the equivalencies of two
/// blocks. Blocks are considered equivalent if they contain the same operations
/// in the same order. The only allowed divergence is for operands that come
/// from sources outside of the parent block, i.e. the uses of values produced
/// within the block must be equivalent.
///   e.g.,
/// Equivalent:
///  ^bb1(%arg0: i32)
///    return %arg0, %foo : i32, i32
///  ^bb2(%arg1: i32)
///    return %arg1, %bar : i32, i32
/// Not Equivalent:
///  ^bb1(%arg0: i32)
///    return %foo, %arg0 : i32, i32
///  ^bb2(%arg1: i32)
///    return %arg1, %bar : i32, i32
struct BlockEquivalenceData {
  BlockEquivalenceData(Block *block);

  /// Return the order index for the given value that is within the block of
  /// this data.
  unsigned getOrderOf(Value value) const;

  /// The block this data refers to.
  Block *block;
  /// A hash value for this block.
  llvm::hash_code hash;
  /// A map of result producing operations to their relative orders within this
  /// block. The order of an operation is the number of defined values that are
  /// produced within the block before this operation.
  DenseMap<Operation *, unsigned> opOrderIndex;
};
} // namespace

BlockEquivalenceData::BlockEquivalenceData(Block *block)
    : block(block), hash(0) {
  unsigned orderIt = block->getNumArguments();
  for (Operation &op : *block) {
    if (unsigned numResults = op.getNumResults()) {
      opOrderIndex.try_emplace(&op, orderIt);
      orderIt += numResults;
    }
    auto opHash = OperationEquivalence::computeHash(
        &op, OperationEquivalence::ignoreHashValue,
        OperationEquivalence::ignoreHashValue,
        OperationEquivalence::IgnoreLocations);
    hash = llvm::hash_combine(hash, opHash);
  }
}

unsigned BlockEquivalenceData::getOrderOf(Value value) const {
  assert(value.getParentBlock() == block && "expected value of this block");

  // Arguments use the argument number as the order index.
  if (BlockArgument arg = dyn_cast<BlockArgument>(value))
    return arg.getArgNumber();

  // Otherwise, the result order is offset from the parent op's order.
  OpResult result = cast<OpResult>(value);
  auto opOrderIt = opOrderIndex.find(result.getDefiningOp());
  assert(opOrderIt != opOrderIndex.end() && "expected op to have an order");
  return opOrderIt->second + result.getResultNumber();
}

//===----------------------------------------------------------------------===//
// BlockMergeCluster

namespace {
/// This class represents a cluster of blocks to be merged together.
class BlockMergeCluster {
public:
  BlockMergeCluster(BlockEquivalenceData &&leaderData)
      : leaderData(std::move(leaderData)) {}

  /// Attempt to add the given block to this cluster. Returns success if the
  /// block was merged, failure otherwise.
  LogicalResult addToCluster(BlockEquivalenceData &blockData);

  /// Try to merge all of the blocks within this cluster into the leader block.
  LogicalResult merge(RewriterBase &rewriter);

private:
  /// The equivalence data for the leader of the cluster.
  BlockEquivalenceData leaderData;

  /// The set of blocks that can be merged into the leader.
  llvm::SmallSetVector<Block *, 1> blocksToMerge;

  /// A set of operand+index pairs that correspond to operands that need to be
  /// replaced by arguments when the cluster gets merged.
  std::set<std::pair<int, int>> operandsToMerge;
};
} // namespace

LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
  if (leaderData.hash != blockData.hash)
    return failure();
  Block *leaderBlock = leaderData.block, *mergeBlock = blockData.block;
  if (leaderBlock->getArgumentTypes() != mergeBlock->getArgumentTypes())
    return failure();

  // A set of operands that mismatch between the leader and the new block.
  SmallVector<std::pair<int, int>, 8> mismatchedOperands;
  auto lhsIt = leaderBlock->begin(), lhsE = leaderBlock->end();
  auto rhsIt = blockData.block->begin(), rhsE = blockData.block->end();
  for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) {
    // Check that the operations are equivalent.
    if (!OperationEquivalence::isEquivalentTo(
            &*lhsIt, &*rhsIt, OperationEquivalence::ignoreValueEquivalence,
            /*markEquivalent=*/nullptr,
            OperationEquivalence::Flags::IgnoreLocations))
      return failure();

    // Compare the operands of the two operations. If the operand is within
    // the block, it must refer to the same operation.
    auto lhsOperands = lhsIt->getOperands(), rhsOperands = rhsIt->getOperands();
    for (int operand : llvm::seq<int>(0, lhsIt->getNumOperands())) {
      Value lhsOperand = lhsOperands[operand];
      Value rhsOperand = rhsOperands[operand];
      if (lhsOperand == rhsOperand)
        continue;
      // Check that the types of the operands match.
      if (lhsOperand.getType() != rhsOperand.getType())
        return failure();

      // Check that these uses are both external, or both internal.
      bool lhsIsInBlock = lhsOperand.getParentBlock() == leaderBlock;
      bool rhsIsInBlock = rhsOperand.getParentBlock() == mergeBlock;
      if (lhsIsInBlock != rhsIsInBlock)
        return failure();
      // Let the operands differ if they are defined in a different block. These
      // will become new arguments if the blocks get merged.
      if (!lhsIsInBlock) {

        // Check whether the operand is the result of an immediate
        // predecessor's terminator. In that case we cannot use it as a
        // successor operand when branching to the merged block, as the
        // producing operation would not dominate that use.
        auto isValidSuccessorArg = [](Block *block, Value operand) {
          if (operand.getDefiningOp() !=
              operand.getParentBlock()->getTerminator())
            return true;
          return !llvm::is_contained(block->getPredecessors(),
                                     operand.getParentBlock());
        };

        if (!isValidSuccessorArg(leaderBlock, lhsOperand) ||
            !isValidSuccessorArg(mergeBlock, rhsOperand))
          return failure();

        mismatchedOperands.emplace_back(opI, operand);
        continue;
      }

      // Otherwise, these operands must have the same logical order within the
      // parent block.
      if (leaderData.getOrderOf(lhsOperand) != blockData.getOrderOf(rhsOperand))
        return failure();
    }

    // If the lhs or rhs has external uses, the blocks cannot be merged as the
    // merged version of this operation will not be either the lhs or rhs
    // alone (thus semantically incorrect), but some mix depending on which
    // block preceded this.
    // TODO: Allow merging of operations when one block does not dominate the
    // other.
    if (rhsIt->isUsedOutsideOfBlock(mergeBlock) ||
        lhsIt->isUsedOutsideOfBlock(leaderBlock)) {
      return failure();
    }
  }
  // Make sure that the block sizes are equivalent.
  if (lhsIt != lhsE || rhsIt != rhsE)
    return failure();

  // If we get here, the blocks are equivalent and can be merged.
  operandsToMerge.insert(mismatchedOperands.begin(), mismatchedOperands.end());
  blocksToMerge.insert(blockData.block);
  return success();
}

/// Returns true if the predecessor terminators of the given block can have
/// their operands updated.
static bool ableToUpdatePredOperands(Block *block) {
  for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
    if (!isa<BranchOpInterface>((*it)->getTerminator()))
      return false;
  }
  return true;
}

/// Prunes redundant entries from the lists of new arguments. E.g., if we are
/// passing an argument list like [x, y, z, x] this would return [x, y, z] and
/// would update `block` (to which the arguments are passed) accordingly. The
/// new arguments are appended at the back of the block, hence we need to know
/// how many old arguments (`numOldArguments`) the block had in order to
/// correctly replace the new arguments in the block.
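///
/// For illustration: with two forwarded-argument lists
///   list0: [%a, %b, %a]
///   list1: [%c, %d, %c]
/// index 2 duplicates index 0 in every list, so both lists are pruned to their
/// first two entries, uses of the corresponding new block argument are
/// redirected to the argument added for index 0, and the redundant argument is
/// erased from `block`.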
static SmallVector<SmallVector<Value, 8>, 2> pruneRedundantArguments(
    const SmallVector<SmallVector<Value, 8>, 2> &newArguments,
    RewriterBase &rewriter, unsigned numOldArguments, Block *block) {

  SmallVector<SmallVector<Value, 8>, 2> newArgumentsPruned(
      newArguments.size(), SmallVector<Value, 8>());

  if (newArguments.empty())
    return newArguments;

  // `newArguments` is a 2D array of size `numLists` x `numArgs`.
  unsigned numLists = newArguments.size();
  unsigned numArgs = newArguments[0].size();

  // For each argument index, this map contains the index that can be used in
  // place of the original index. E.g., if we have newArgs = [x, y, z, x], we
  // will have idxToReplacement[3] = 0.
  llvm::DenseMap<unsigned, unsigned> idxToReplacement;

  // This data structure tracks the first appearance of a Value in the first
  // list of arguments (list 0).
  DenseMap<Value, unsigned> firstValueToIdx;
  for (unsigned j = 0; j < numArgs; ++j) {
    Value newArg = newArguments[0][j];
    firstValueToIdx.try_emplace(newArg, j);
  }

  // Go through the first list of arguments (list 0).
  for (unsigned j = 0; j < numArgs; ++j) {
    // Look back to see if there are possible redundancies in list 0. Please
    // note that we are using a map to annotate when an argument was seen first
    // to avoid an O(N^2) algorithm. This has the drawback that if we have two
    // lists like:
    // list0: [%a, %a, %a]
    // list1: [%c, %b, %b]
    // we cannot simplify it, because firstValueToIdx[%a] = 0, but we cannot
    // point list1[1](==%b) or list1[2](==%b) to list1[0](==%c). However, since
    // the number of arguments can be potentially unbounded, we cannot afford an
    // O(N^2) algorithm (to search all the possible pairs) and we need to
    // accept the trade-off.
    unsigned k = firstValueToIdx[newArguments[0][j]];
    if (k == j)
      continue;

    bool shouldReplaceJ = true;
    unsigned replacement = k;
    // If a possible redundancy is found, then scan the other lists: we
    // can prune the arguments if and only if they are redundant in every
    // list.
    for (unsigned i = 1; i < numLists; ++i)
      shouldReplaceJ =
          shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
    // Save the replacement.
    if (shouldReplaceJ)
      idxToReplacement[j] = replacement;
  }

  // Populate the pruned argument list.
  for (unsigned i = 0; i < numLists; ++i)
    for (unsigned j = 0; j < numArgs; ++j)
      if (!idxToReplacement.contains(j))
        newArgumentsPruned[i].push_back(newArguments[i][j]);

  // Replace the block's redundant arguments.
  SmallVector<unsigned> toErase;
  for (auto [idx, arg] : llvm::enumerate(block->getArguments())) {
    if (idxToReplacement.contains(idx)) {
      Value oldArg = block->getArgument(numOldArguments + idx);
      Value newArg =
          block->getArgument(numOldArguments + idxToReplacement[idx]);
      rewriter.replaceAllUsesWith(oldArg, newArg);
      toErase.push_back(numOldArguments + idx);
    }
  }

  // Erase the block's redundant arguments.
  for (unsigned idxToErase : llvm::reverse(toErase))
    block->eraseArgument(idxToErase);
  return newArgumentsPruned;
}

LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
  // Don't consider clusters that don't have blocks to merge.
  if (blocksToMerge.empty())
    return failure();

  Block *leaderBlock = leaderData.block;
  if (!operandsToMerge.empty()) {
    // If the cluster has operands to merge, verify that the predecessor
    // terminators of each of the blocks can have their successor operands
    // updated.
    // TODO: We could try and sub-partition this cluster if only some blocks
    // cause the mismatch.
    if (!ableToUpdatePredOperands(leaderBlock) ||
        !llvm::all_of(blocksToMerge, ableToUpdatePredOperands))
      return failure();

    // Collect the iterators for each of the blocks to merge. We will walk all
    // of the iterators at once to avoid operand index invalidation.
    SmallVector<Block::iterator, 2> blockIterators;
    blockIterators.reserve(blocksToMerge.size() + 1);
    blockIterators.push_back(leaderBlock->begin());
    for (Block *mergeBlock : blocksToMerge)
      blockIterators.push_back(mergeBlock->begin());

    // Update each of the predecessor terminators with the new arguments.
    SmallVector<SmallVector<Value, 8>, 2> newArguments(
        1 + blocksToMerge.size(),
        SmallVector<Value, 8>(operandsToMerge.size()));
    unsigned curOpIndex = 0;
    unsigned numOldArguments = leaderBlock->getNumArguments();
    for (const auto &it : llvm::enumerate(operandsToMerge)) {
      unsigned nextOpOffset = it.value().first - curOpIndex;
      curOpIndex = it.value().first;

      // Process the operand for each of the block iterators.
      for (unsigned i = 0, e = blockIterators.size(); i != e; ++i) {
        Block::iterator &blockIter = blockIterators[i];
        std::advance(blockIter, nextOpOffset);
        auto &operand = blockIter->getOpOperand(it.value().second);
        newArguments[i][it.index()] = operand.get();

        // Update the operand and insert an argument if this is the leader.
        if (i == 0) {
          Value operandVal = operand.get();
          operand.set(leaderBlock->addArgument(operandVal.getType(),
                                               operandVal.getLoc()));
        }
      }
    }

    // Prune redundant arguments and update the leader block's argument list.
    newArguments = pruneRedundantArguments(newArguments, rewriter,
                                           numOldArguments, leaderBlock);

    // Update the predecessors for each of the blocks.
    auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
      for (auto predIt = block->pred_begin(), predE = block->pred_end();
           predIt != predE; ++predIt) {
        auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
        unsigned succIndex = predIt.getSuccessorIndex();
        branch.getSuccessorOperands(succIndex).append(
            newArguments[clusterIndex]);
      }
    };
    updatePredecessors(leaderBlock, /*clusterIndex=*/0);
    for (unsigned i = 0, e = blocksToMerge.size(); i != e; ++i)
      updatePredecessors(blocksToMerge[i], /*clusterIndex=*/i + 1);
  }

  // Replace all uses of the merged blocks with the leader and erase them.
  for (Block *block : blocksToMerge) {
    block->replaceAllUsesWith(leaderBlock);
    rewriter.eraseBlock(block);
  }
  return success();
}

/// Identify identical blocks within the given region and merge them, inserting
/// new block arguments as necessary. Returns success if any blocks were merged,
/// failure otherwise.
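///
/// For illustration (hypothetical IR): the two blocks
///
///   ^bb1:
///     "test.use"(%a) : (i32) -> ()
///     cf.br ^bb3
///   ^bb2:
///     "test.use"(%b) : (i32) -> ()
///     cf.br ^bb3
///
/// are identical except for the externally defined operands %a and %b, so they
/// can be merged into a single block that receives the differing value as a
/// new block argument, with each predecessor forwarding %a or %b accordingly.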
static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
                                          Region &region) {
  if (region.empty() || llvm::hasSingleElement(region))
    return failure();

  // Identify sets of blocks, other than the entry block, that branch to the
  // same successors. We will use these groups to create clusters of equivalent
  // blocks.
  DenseMap<SuccessorRange, SmallVector<Block *, 1>> matchingSuccessors;
  for (Block &block : llvm::drop_begin(region, 1))
    matchingSuccessors[block.getSuccessors()].push_back(&block);

  bool mergedAnyBlocks = false;
  for (ArrayRef<Block *> blocks : llvm::make_second_range(matchingSuccessors)) {
    if (blocks.size() == 1)
      continue;

    SmallVector<BlockMergeCluster, 1> clusters;
    for (Block *block : blocks) {
      BlockEquivalenceData data(block);

      // Don't allow merging if this block has any regions.
      // TODO: Add support for regions if necessary.
      bool hasNonEmptyRegion = llvm::any_of(*block, [](Operation &op) {
        return llvm::any_of(op.getRegions(),
                            [](Region &region) { return !region.empty(); });
      });
      if (hasNonEmptyRegion)
        continue;

      // Don't allow merging if this block's arguments are used outside of the
      // original block.
      bool argHasExternalUsers = llvm::any_of(
          block->getArguments(), [block](mlir::BlockArgument &arg) {
            return arg.isUsedOutsideOfBlock(block);
          });
      if (argHasExternalUsers)
        continue;

      // Try to add this block to an existing cluster.
      bool addedToCluster = false;
      for (auto &cluster : clusters)
        if ((addedToCluster = succeeded(cluster.addToCluster(data))))
          break;
      if (!addedToCluster)
        clusters.emplace_back(std::move(data));
    }
    for (auto &cluster : clusters)
      mergedAnyBlocks |= succeeded(cluster.merge(rewriter));
  }

  return success(mergedAnyBlocks);
}

/// Identify identical blocks within the given regions and merge them, inserting
/// new block arguments as necessary.
static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
                                          MutableArrayRef<Region> regions) {
  llvm::SmallSetVector<Region *, 1> worklist;
  for (auto &region : regions)
    worklist.insert(&region);
  bool anyChanged = false;
  while (!worklist.empty()) {
    Region *region = worklist.pop_back_val();
    if (succeeded(mergeIdenticalBlocks(rewriter, *region))) {
      worklist.insert(region);
      anyChanged = true;
    }

    // Add any nested regions to the worklist.
    for (Block &block : *region)
      for (auto &op : block)
        for (auto &nestedRegion : op.getRegions())
          worklist.insert(&nestedRegion);
  }

  return success(anyChanged);
}

/// If a block's argument is always the same across different invocations, then
/// drop the argument and use the value directly inside the block.
static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
                                            Block &block) {
  SmallVector<size_t> argsToErase;

  // Go through the arguments of the block.
  for (auto [argIdx, blockOperand] : llvm::enumerate(block.getArguments())) {
    bool sameArg = true;
    Value commonValue;

    // Go through the block predecessors and flag if they pass different
    // values to the block for the same argument.
    for (Block::pred_iterator predIt = block.pred_begin(),
                              predE = block.pred_end();
         predIt != predE; ++predIt) {
      auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
      if (!branch) {
        sameArg = false;
        break;
      }
      unsigned succIndex = predIt.getSuccessorIndex();
      SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
      auto branchOperands = succOperands.getForwardedOperands();
      if (!commonValue) {
        commonValue = branchOperands[argIdx];
        continue;
      }
      if (branchOperands[argIdx] != commonValue) {
        sameArg = false;
        break;
      }
    }

    // If they are passing the same value, drop the argument.
    if (commonValue && sameArg) {
      argsToErase.push_back(argIdx);

      // Redirect all uses of the argument to the common value.
      rewriter.replaceAllUsesWith(blockOperand, commonValue);
    }
  }

  // Remove the arguments.
  for (size_t argIdx : llvm::reverse(argsToErase)) {
    block.eraseArgument(argIdx);

    // Remove the argument from the branch ops.
    for (auto predIt = block.pred_begin(), predE = block.pred_end();
         predIt != predE; ++predIt) {
      auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
      unsigned succIndex = predIt.getSuccessorIndex();
      SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
      succOperands.erase(argIdx);
    }
  }
  return success(!argsToErase.empty());
}

/// This optimization drops redundant arguments to blocks. I.e., if a given
/// argument to a block receives the same value from each of the block
/// predecessors, we can remove the argument from the block and use the
/// original value directly. Here is a simple example:
///
/// %cond = llvm.call @rand() : () -> i1
/// %val0 = llvm.mlir.constant(1 : i64) : i64
/// %val1 = llvm.mlir.constant(2 : i64) : i64
/// %val2 = llvm.mlir.constant(3 : i64) : i64
/// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
/// : i64)
///
/// ^bb1(%arg0 : i64, %arg1 : i64):
///    llvm.call @foo(%arg0, %arg1)
///
/// The previous IR can be rewritten as:
/// %cond = llvm.call @rand() : () -> i1
/// %val0 = llvm.mlir.constant(1 : i64) : i64
/// %val1 = llvm.mlir.constant(2 : i64) : i64
/// %val2 = llvm.mlir.constant(3 : i64) : i64
/// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
///
/// ^bb1(%arg0 : i64):
///    llvm.call @foo(%val0, %arg0)
///
static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
                                            MutableArrayRef<Region> regions) {
  llvm::SmallSetVector<Region *, 1> worklist;
  for (Region &region : regions)
    worklist.insert(&region);
  bool anyChanged = false;
  while (!worklist.empty()) {
    Region *region = worklist.pop_back_val();

    // Process each block of the region and add any nested regions to the
    // worklist.
    for (Block &block : *region) {
      anyChanged =
          succeeded(dropRedundantArguments(rewriter, block)) || anyChanged;

      for (Operation &op : block)
        for (Region &nestedRegion : op.getRegions())
          worklist.insert(&nestedRegion);
    }
  }
  return success(anyChanged);
}

//===----------------------------------------------------------------------===//
// Region Simplification
//===----------------------------------------------------------------------===//

/// Run a set of structural simplifications over the given regions. This
/// includes transformations like unreachable block elimination, dead argument
/// elimination, as well as some other DCE. This function returns success if any
/// of the regions were simplified, failure otherwise.
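///
/// A minimal usage sketch (illustrative only; for example, the greedy pattern
/// rewrite driver invokes this entry point in a similar way):
///
///   IRRewriter rewriter(op->getContext());
///   if (succeeded(simplifyRegions(rewriter, op->getRegions())))
///     ...;  // Something was simplified.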
LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
                                    MutableArrayRef<Region> regions,
                                    bool mergeBlocks) {
  bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
  bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
  bool mergedIdenticalBlocks = false;
  bool droppedRedundantArguments = false;
  if (mergeBlocks) {
    mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
    droppedRedundantArguments =
        succeeded(dropRedundantArguments(rewriter, regions));
  }
  return success(eliminatedBlocks || eliminatedOpsOrArgs ||
                 mergedIdenticalBlocks || droppedRedundantArguments);
}
1057