1 //===- RegionUtils.cpp - Region-related transformation utilities ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Transforms/RegionUtils.h" 10 #include "mlir/IR/Block.h" 11 #include "mlir/IR/Operation.h" 12 #include "mlir/IR/PatternMatch.h" 13 #include "mlir/IR/RegionGraphTraits.h" 14 #include "mlir/IR/Value.h" 15 #include "mlir/Interfaces/ControlFlowInterfaces.h" 16 #include "mlir/Interfaces/SideEffectInterfaces.h" 17 18 #include "llvm/ADT/DepthFirstIterator.h" 19 #include "llvm/ADT/PostOrderIterator.h" 20 #include "llvm/ADT/SmallSet.h" 21 22 using namespace mlir; 23 24 void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement, 25 Region ®ion) { 26 for (auto &use : llvm::make_early_inc_range(orig.getUses())) { 27 if (region.isAncestor(use.getOwner()->getParentRegion())) 28 use.set(replacement); 29 } 30 } 31 32 void mlir::visitUsedValuesDefinedAbove( 33 Region ®ion, Region &limit, function_ref<void(OpOperand *)> callback) { 34 assert(limit.isAncestor(®ion) && 35 "expected isolation limit to be an ancestor of the given region"); 36 37 // Collect proper ancestors of `limit` upfront to avoid traversing the region 38 // tree for every value. 39 SmallPtrSet<Region *, 4> properAncestors; 40 for (auto *reg = limit.getParentRegion(); reg != nullptr; 41 reg = reg->getParentRegion()) { 42 properAncestors.insert(reg); 43 } 44 45 region.walk([callback, &properAncestors](Operation *op) { 46 for (OpOperand &operand : op->getOpOperands()) 47 // Callback on values defined in a proper ancestor of region. 48 if (properAncestors.count(operand.get().getParentRegion())) 49 callback(&operand); 50 }); 51 } 52 53 void mlir::visitUsedValuesDefinedAbove( 54 MutableArrayRef<Region> regions, function_ref<void(OpOperand *)> callback) { 55 for (Region ®ion : regions) 56 visitUsedValuesDefinedAbove(region, region, callback); 57 } 58 59 void mlir::getUsedValuesDefinedAbove(Region ®ion, Region &limit, 60 SetVector<Value> &values) { 61 visitUsedValuesDefinedAbove(region, limit, [&](OpOperand *operand) { 62 values.insert(operand->get()); 63 }); 64 } 65 66 void mlir::getUsedValuesDefinedAbove(MutableArrayRef<Region> regions, 67 SetVector<Value> &values) { 68 for (Region ®ion : regions) 69 getUsedValuesDefinedAbove(region, region, values); 70 } 71 72 //===----------------------------------------------------------------------===// 73 // Unreachable Block Elimination 74 //===----------------------------------------------------------------------===// 75 76 /// Erase the unreachable blocks within the provided regions. Returns success 77 /// if any blocks were erased, failure otherwise. 78 // TODO: We could likely merge this with the DCE algorithm below. 79 LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter, 80 MutableArrayRef<Region> regions) { 81 // Set of blocks found to be reachable within a given region. 82 llvm::df_iterator_default_set<Block *, 16> reachable; 83 // If any blocks were found to be dead. 84 bool erasedDeadBlocks = false; 85 86 SmallVector<Region *, 1> worklist; 87 worklist.reserve(regions.size()); 88 for (Region ®ion : regions) 89 worklist.push_back(®ion); 90 while (!worklist.empty()) { 91 Region *region = worklist.pop_back_val(); 92 if (region->empty()) 93 continue; 94 95 // If this is a single block region, just collect the nested regions. 96 if (std::next(region->begin()) == region->end()) { 97 for (Operation &op : region->front()) 98 for (Region ®ion : op.getRegions()) 99 worklist.push_back(®ion); 100 continue; 101 } 102 103 // Mark all reachable blocks. 104 reachable.clear(); 105 for (Block *block : depth_first_ext(®ion->front(), reachable)) 106 (void)block /* Mark all reachable blocks */; 107 108 // Collect all of the dead blocks and push the live regions onto the 109 // worklist. 110 for (Block &block : llvm::make_early_inc_range(*region)) { 111 if (!reachable.count(&block)) { 112 block.dropAllDefinedValueUses(); 113 rewriter.eraseBlock(&block); 114 erasedDeadBlocks = true; 115 continue; 116 } 117 118 // Walk any regions within this block. 119 for (Operation &op : block) 120 for (Region ®ion : op.getRegions()) 121 worklist.push_back(®ion); 122 } 123 } 124 125 return success(erasedDeadBlocks); 126 } 127 128 //===----------------------------------------------------------------------===// 129 // Dead Code Elimination 130 //===----------------------------------------------------------------------===// 131 132 namespace { 133 /// Data structure used to track which values have already been proved live. 134 /// 135 /// Because Operation's can have multiple results, this data structure tracks 136 /// liveness for both Value's and Operation's to avoid having to look through 137 /// all Operation results when analyzing a use. 138 /// 139 /// This data structure essentially tracks the dataflow lattice. 140 /// The set of values/ops proved live increases monotonically to a fixed-point. 141 class LiveMap { 142 public: 143 /// Value methods. 144 bool wasProvenLive(Value value) { 145 // TODO: For results that are removable, e.g. for region based control flow, 146 // we could allow for these values to be tracked independently. 147 if (OpResult result = value.dyn_cast<OpResult>()) 148 return wasProvenLive(result.getOwner()); 149 return wasProvenLive(value.cast<BlockArgument>()); 150 } 151 bool wasProvenLive(BlockArgument arg) { return liveValues.count(arg); } 152 void setProvedLive(Value value) { 153 // TODO: For results that are removable, e.g. for region based control flow, 154 // we could allow for these values to be tracked independently. 155 if (OpResult result = value.dyn_cast<OpResult>()) 156 return setProvedLive(result.getOwner()); 157 setProvedLive(value.cast<BlockArgument>()); 158 } 159 void setProvedLive(BlockArgument arg) { 160 changed |= liveValues.insert(arg).second; 161 } 162 163 /// Operation methods. 164 bool wasProvenLive(Operation *op) { return liveOps.count(op); } 165 void setProvedLive(Operation *op) { changed |= liveOps.insert(op).second; } 166 167 /// Methods for tracking if we have reached a fixed-point. 168 void resetChanged() { changed = false; } 169 bool hasChanged() { return changed; } 170 171 private: 172 bool changed = false; 173 DenseSet<Value> liveValues; 174 DenseSet<Operation *> liveOps; 175 }; 176 } // namespace 177 178 static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap) { 179 Operation *owner = use.getOwner(); 180 unsigned operandIndex = use.getOperandNumber(); 181 // This pass generally treats all uses of an op as live if the op itself is 182 // considered live. However, for successor operands to terminators we need a 183 // finer-grained notion where we deduce liveness for operands individually. 184 // The reason for this is easiest to think about in terms of a classical phi 185 // node based SSA IR, where each successor operand is really an operand to a 186 // *separate* phi node, rather than all operands to the branch itself as with 187 // the block argument representation that MLIR uses. 188 // 189 // And similarly, because each successor operand is really an operand to a phi 190 // node, rather than to the terminator op itself, a terminator op can't e.g. 191 // "print" the value of a successor operand. 192 if (owner->hasTrait<OpTrait::IsTerminator>()) { 193 if (BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(owner)) 194 if (auto arg = branchInterface.getSuccessorBlockArgument(operandIndex)) 195 return !liveMap.wasProvenLive(*arg); 196 return false; 197 } 198 return false; 199 } 200 201 static void processValue(Value value, LiveMap &liveMap) { 202 bool provedLive = llvm::any_of(value.getUses(), [&](OpOperand &use) { 203 if (isUseSpeciallyKnownDead(use, liveMap)) 204 return false; 205 return liveMap.wasProvenLive(use.getOwner()); 206 }); 207 if (provedLive) 208 liveMap.setProvedLive(value); 209 } 210 211 static void propagateLiveness(Region ®ion, LiveMap &liveMap); 212 213 static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) { 214 // Terminators are always live. 215 liveMap.setProvedLive(op); 216 217 // Check to see if we can reason about the successor operands and mutate them. 218 BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(op); 219 if (!branchInterface) { 220 for (Block *successor : op->getSuccessors()) 221 for (BlockArgument arg : successor->getArguments()) 222 liveMap.setProvedLive(arg); 223 return; 224 } 225 226 // If we can't reason about the operands to a successor, conservatively mark 227 // all arguments as live. 228 for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) { 229 if (!branchInterface.getMutableSuccessorOperands(i)) 230 for (BlockArgument arg : op->getSuccessor(i)->getArguments()) 231 liveMap.setProvedLive(arg); 232 } 233 } 234 235 static void propagateLiveness(Operation *op, LiveMap &liveMap) { 236 // Recurse on any regions the op has. 237 for (Region ®ion : op->getRegions()) 238 propagateLiveness(region, liveMap); 239 240 // Process terminator operations. 241 if (op->hasTrait<OpTrait::IsTerminator>()) 242 return propagateTerminatorLiveness(op, liveMap); 243 244 // Don't reprocess live operations. 245 if (liveMap.wasProvenLive(op)) 246 return; 247 248 // Process the op itself. 249 if (!wouldOpBeTriviallyDead(op)) 250 return liveMap.setProvedLive(op); 251 252 // If the op isn't intrinsically alive, check it's results. 253 for (Value value : op->getResults()) 254 processValue(value, liveMap); 255 } 256 257 static void propagateLiveness(Region ®ion, LiveMap &liveMap) { 258 if (region.empty()) 259 return; 260 261 for (Block *block : llvm::post_order(®ion.front())) { 262 // We process block arguments after the ops in the block, to promote 263 // faster convergence to a fixed point (we try to visit uses before defs). 264 for (Operation &op : llvm::reverse(block->getOperations())) 265 propagateLiveness(&op, liveMap); 266 267 // We currently do not remove entry block arguments, so there is no need to 268 // track their liveness. 269 // TODO: We could track these and enable removing dead operands/arguments 270 // from region control flow operations. 271 if (block->isEntryBlock()) 272 continue; 273 274 for (Value value : block->getArguments()) { 275 if (!liveMap.wasProvenLive(value)) 276 processValue(value, liveMap); 277 } 278 } 279 } 280 281 static void eraseTerminatorSuccessorOperands(Operation *terminator, 282 LiveMap &liveMap) { 283 BranchOpInterface branchOp = dyn_cast<BranchOpInterface>(terminator); 284 if (!branchOp) 285 return; 286 287 for (unsigned succI = 0, succE = terminator->getNumSuccessors(); 288 succI < succE; succI++) { 289 // Iterating successors in reverse is not strictly needed, since we 290 // aren't erasing any successors. But it is slightly more efficient 291 // since it will promote later operands of the terminator being erased 292 // first, reducing the quadratic-ness. 293 unsigned succ = succE - succI - 1; 294 Optional<MutableOperandRange> succOperands = 295 branchOp.getMutableSuccessorOperands(succ); 296 if (!succOperands) 297 continue; 298 Block *successor = terminator->getSuccessor(succ); 299 300 for (unsigned argI = 0, argE = succOperands->size(); argI < argE; ++argI) { 301 // Iterating args in reverse is needed for correctness, to avoid 302 // shifting later args when earlier args are erased. 303 unsigned arg = argE - argI - 1; 304 if (!liveMap.wasProvenLive(successor->getArgument(arg))) 305 succOperands->erase(arg); 306 } 307 } 308 } 309 310 static LogicalResult deleteDeadness(RewriterBase &rewriter, 311 MutableArrayRef<Region> regions, 312 LiveMap &liveMap) { 313 bool erasedAnything = false; 314 for (Region ®ion : regions) { 315 if (region.empty()) 316 continue; 317 bool hasSingleBlock = llvm::hasSingleElement(region); 318 319 // Delete every operation that is not live. Graph regions may have cycles 320 // in the use-def graph, so we must explicitly dropAllUses() from each 321 // operation as we erase it. Visiting the operations in post-order 322 // guarantees that in SSA CFG regions value uses are removed before defs, 323 // which makes dropAllUses() a no-op. 324 for (Block *block : llvm::post_order(®ion.front())) { 325 if (!hasSingleBlock) 326 eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap); 327 for (Operation &childOp : 328 llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) { 329 if (!liveMap.wasProvenLive(&childOp)) { 330 erasedAnything = true; 331 childOp.dropAllUses(); 332 rewriter.eraseOp(&childOp); 333 } else { 334 erasedAnything |= succeeded( 335 deleteDeadness(rewriter, childOp.getRegions(), liveMap)); 336 } 337 } 338 } 339 // Delete block arguments. 340 // The entry block has an unknown contract with their enclosing block, so 341 // skip it. 342 for (Block &block : llvm::drop_begin(region.getBlocks(), 1)) { 343 block.eraseArguments( 344 [&](BlockArgument arg) { return !liveMap.wasProvenLive(arg); }); 345 } 346 } 347 return success(erasedAnything); 348 } 349 350 // This function performs a simple dead code elimination algorithm over the 351 // given regions. 352 // 353 // The overall goal is to prove that Values are dead, which allows deleting ops 354 // and block arguments. 355 // 356 // This uses an optimistic algorithm that assumes everything is dead until 357 // proved otherwise, allowing it to delete recursively dead cycles. 358 // 359 // This is a simple fixed-point dataflow analysis algorithm on a lattice 360 // {Dead,Alive}. Because liveness flows backward, we generally try to 361 // iterate everything backward to speed up convergence to the fixed-point. This 362 // allows for being able to delete recursively dead cycles of the use-def graph, 363 // including block arguments. 364 // 365 // This function returns success if any operations or arguments were deleted, 366 // failure otherwise. 367 LogicalResult mlir::runRegionDCE(RewriterBase &rewriter, 368 MutableArrayRef<Region> regions) { 369 LiveMap liveMap; 370 do { 371 liveMap.resetChanged(); 372 373 for (Region ®ion : regions) 374 propagateLiveness(region, liveMap); 375 } while (liveMap.hasChanged()); 376 377 return deleteDeadness(rewriter, regions, liveMap); 378 } 379 380 //===----------------------------------------------------------------------===// 381 // Block Merging 382 //===----------------------------------------------------------------------===// 383 384 //===----------------------------------------------------------------------===// 385 // BlockEquivalenceData 386 387 namespace { 388 /// This class contains the information for comparing the equivalencies of two 389 /// blocks. Blocks are considered equivalent if they contain the same operations 390 /// in the same order. The only allowed divergence is for operands that come 391 /// from sources outside of the parent block, i.e. the uses of values produced 392 /// within the block must be equivalent. 393 /// e.g., 394 /// Equivalent: 395 /// ^bb1(%arg0: i32) 396 /// return %arg0, %foo : i32, i32 397 /// ^bb2(%arg1: i32) 398 /// return %arg1, %bar : i32, i32 399 /// Not Equivalent: 400 /// ^bb1(%arg0: i32) 401 /// return %foo, %arg0 : i32, i32 402 /// ^bb2(%arg1: i32) 403 /// return %arg1, %bar : i32, i32 404 struct BlockEquivalenceData { 405 BlockEquivalenceData(Block *block); 406 407 /// Return the order index for the given value that is within the block of 408 /// this data. 409 unsigned getOrderOf(Value value) const; 410 411 /// The block this data refers to. 412 Block *block; 413 /// A hash value for this block. 414 llvm::hash_code hash; 415 /// A map of result producing operations to their relative orders within this 416 /// block. The order of an operation is the number of defined values that are 417 /// produced within the block before this operation. 418 DenseMap<Operation *, unsigned> opOrderIndex; 419 }; 420 } // namespace 421 422 BlockEquivalenceData::BlockEquivalenceData(Block *block) 423 : block(block), hash(0) { 424 unsigned orderIt = block->getNumArguments(); 425 for (Operation &op : *block) { 426 if (unsigned numResults = op.getNumResults()) { 427 opOrderIndex.try_emplace(&op, orderIt); 428 orderIt += numResults; 429 } 430 auto opHash = OperationEquivalence::computeHash( 431 &op, OperationEquivalence::ignoreHashValue, 432 OperationEquivalence::ignoreHashValue, 433 OperationEquivalence::IgnoreLocations); 434 hash = llvm::hash_combine(hash, opHash); 435 } 436 } 437 438 unsigned BlockEquivalenceData::getOrderOf(Value value) const { 439 assert(value.getParentBlock() == block && "expected value of this block"); 440 441 // Arguments use the argument number as the order index. 442 if (BlockArgument arg = value.dyn_cast<BlockArgument>()) 443 return arg.getArgNumber(); 444 445 // Otherwise, the result order is offset from the parent op's order. 446 OpResult result = value.cast<OpResult>(); 447 auto opOrderIt = opOrderIndex.find(result.getDefiningOp()); 448 assert(opOrderIt != opOrderIndex.end() && "expected op to have an order"); 449 return opOrderIt->second + result.getResultNumber(); 450 } 451 452 //===----------------------------------------------------------------------===// 453 // BlockMergeCluster 454 455 namespace { 456 /// This class represents a cluster of blocks to be merged together. 457 class BlockMergeCluster { 458 public: 459 BlockMergeCluster(BlockEquivalenceData &&leaderData) 460 : leaderData(std::move(leaderData)) {} 461 462 /// Attempt to add the given block to this cluster. Returns success if the 463 /// block was merged, failure otherwise. 464 LogicalResult addToCluster(BlockEquivalenceData &blockData); 465 466 /// Try to merge all of the blocks within this cluster into the leader block. 467 LogicalResult merge(RewriterBase &rewriter); 468 469 private: 470 /// The equivalence data for the leader of the cluster. 471 BlockEquivalenceData leaderData; 472 473 /// The set of blocks that can be merged into the leader. 474 llvm::SmallSetVector<Block *, 1> blocksToMerge; 475 476 /// A set of operand+index pairs that correspond to operands that need to be 477 /// replaced by arguments when the cluster gets merged. 478 std::set<std::pair<int, int>> operandsToMerge; 479 }; 480 } // namespace 481 482 LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) { 483 if (leaderData.hash != blockData.hash) 484 return failure(); 485 Block *leaderBlock = leaderData.block, *mergeBlock = blockData.block; 486 if (leaderBlock->getArgumentTypes() != mergeBlock->getArgumentTypes()) 487 return failure(); 488 489 // A set of operands that mismatch between the leader and the new block. 490 SmallVector<std::pair<int, int>, 8> mismatchedOperands; 491 auto lhsIt = leaderBlock->begin(), lhsE = leaderBlock->end(); 492 auto rhsIt = blockData.block->begin(), rhsE = blockData.block->end(); 493 for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) { 494 // Check that the operations are equivalent. 495 if (!OperationEquivalence::isEquivalentTo( 496 &*lhsIt, &*rhsIt, OperationEquivalence::ignoreValueEquivalence, 497 OperationEquivalence::ignoreValueEquivalence, 498 OperationEquivalence::Flags::IgnoreLocations)) 499 return failure(); 500 501 // Compare the operands of the two operations. If the operand is within 502 // the block, it must refer to the same operation. 503 auto lhsOperands = lhsIt->getOperands(), rhsOperands = rhsIt->getOperands(); 504 for (int operand : llvm::seq<int>(0, lhsIt->getNumOperands())) { 505 Value lhsOperand = lhsOperands[operand]; 506 Value rhsOperand = rhsOperands[operand]; 507 if (lhsOperand == rhsOperand) 508 continue; 509 // Check that the types of the operands match. 510 if (lhsOperand.getType() != rhsOperand.getType()) 511 return failure(); 512 513 // Check that these uses are both external, or both internal. 514 bool lhsIsInBlock = lhsOperand.getParentBlock() == leaderBlock; 515 bool rhsIsInBlock = rhsOperand.getParentBlock() == mergeBlock; 516 if (lhsIsInBlock != rhsIsInBlock) 517 return failure(); 518 // Let the operands differ if they are defined in a different block. These 519 // will become new arguments if the blocks get merged. 520 if (!lhsIsInBlock) { 521 522 // Check whether the operands aren't the result of an immediate 523 // predecessors terminator. In that case we are not able to use it as a 524 // successor operand when branching to the merged block as it does not 525 // dominate its producing operation. 526 auto isValidSuccessorArg = [](Block *block, Value operand) { 527 if (operand.getDefiningOp() != 528 operand.getParentBlock()->getTerminator()) 529 return true; 530 return !llvm::is_contained(block->getPredecessors(), 531 operand.getParentBlock()); 532 }; 533 534 if (!isValidSuccessorArg(leaderBlock, lhsOperand) || 535 !isValidSuccessorArg(mergeBlock, rhsOperand)) 536 return failure(); 537 538 mismatchedOperands.emplace_back(opI, operand); 539 continue; 540 } 541 542 // Otherwise, these operands must have the same logical order within the 543 // parent block. 544 if (leaderData.getOrderOf(lhsOperand) != blockData.getOrderOf(rhsOperand)) 545 return failure(); 546 } 547 548 // If the lhs or rhs has external uses, the blocks cannot be merged as the 549 // merged version of this operation will not be either the lhs or rhs 550 // alone (thus semantically incorrect), but some mix dependending on which 551 // block preceeded this. 552 // TODO allow merging of operations when one block does not dominate the 553 // other 554 if (rhsIt->isUsedOutsideOfBlock(mergeBlock) || 555 lhsIt->isUsedOutsideOfBlock(leaderBlock)) { 556 return failure(); 557 } 558 } 559 // Make sure that the block sizes are equivalent. 560 if (lhsIt != lhsE || rhsIt != rhsE) 561 return failure(); 562 563 // If we get here, the blocks are equivalent and can be merged. 564 operandsToMerge.insert(mismatchedOperands.begin(), mismatchedOperands.end()); 565 blocksToMerge.insert(blockData.block); 566 return success(); 567 } 568 569 /// Returns true if the predecessor terminators of the given block can not have 570 /// their operands updated. 571 static bool ableToUpdatePredOperands(Block *block) { 572 for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) { 573 auto branch = dyn_cast<BranchOpInterface>((*it)->getTerminator()); 574 if (!branch || !branch.getMutableSuccessorOperands(it.getSuccessorIndex())) 575 return false; 576 } 577 return true; 578 } 579 580 LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) { 581 // Don't consider clusters that don't have blocks to merge. 582 if (blocksToMerge.empty()) 583 return failure(); 584 585 Block *leaderBlock = leaderData.block; 586 if (!operandsToMerge.empty()) { 587 // If the cluster has operands to merge, verify that the predecessor 588 // terminators of each of the blocks can have their successor operands 589 // updated. 590 // TODO: We could try and sub-partition this cluster if only some blocks 591 // cause the mismatch. 592 if (!ableToUpdatePredOperands(leaderBlock) || 593 !llvm::all_of(blocksToMerge, ableToUpdatePredOperands)) 594 return failure(); 595 596 // Collect the iterators for each of the blocks to merge. We will walk all 597 // of the iterators at once to avoid operand index invalidation. 598 SmallVector<Block::iterator, 2> blockIterators; 599 blockIterators.reserve(blocksToMerge.size() + 1); 600 blockIterators.push_back(leaderBlock->begin()); 601 for (Block *mergeBlock : blocksToMerge) 602 blockIterators.push_back(mergeBlock->begin()); 603 604 // Update each of the predecessor terminators with the new arguments. 605 SmallVector<SmallVector<Value, 8>, 2> newArguments( 606 1 + blocksToMerge.size(), 607 SmallVector<Value, 8>(operandsToMerge.size())); 608 unsigned curOpIndex = 0; 609 for (const auto &it : llvm::enumerate(operandsToMerge)) { 610 unsigned nextOpOffset = it.value().first - curOpIndex; 611 curOpIndex = it.value().first; 612 613 // Process the operand for each of the block iterators. 614 for (unsigned i = 0, e = blockIterators.size(); i != e; ++i) { 615 Block::iterator &blockIter = blockIterators[i]; 616 std::advance(blockIter, nextOpOffset); 617 auto &operand = blockIter->getOpOperand(it.value().second); 618 newArguments[i][it.index()] = operand.get(); 619 620 // Update the operand and insert an argument if this is the leader. 621 if (i == 0) { 622 Value operandVal = operand.get(); 623 operand.set(leaderBlock->addArgument(operandVal.getType(), 624 operandVal.getLoc())); 625 } 626 } 627 } 628 // Update the predecessors for each of the blocks. 629 auto updatePredecessors = [&](Block *block, unsigned clusterIndex) { 630 for (auto predIt = block->pred_begin(), predE = block->pred_end(); 631 predIt != predE; ++predIt) { 632 auto branch = cast<BranchOpInterface>((*predIt)->getTerminator()); 633 unsigned succIndex = predIt.getSuccessorIndex(); 634 branch.getMutableSuccessorOperands(succIndex)->append( 635 newArguments[clusterIndex]); 636 } 637 }; 638 updatePredecessors(leaderBlock, /*clusterIndex=*/0); 639 for (unsigned i = 0, e = blocksToMerge.size(); i != e; ++i) 640 updatePredecessors(blocksToMerge[i], /*clusterIndex=*/i + 1); 641 } 642 643 // Replace all uses of the merged blocks with the leader and erase them. 644 for (Block *block : blocksToMerge) { 645 block->replaceAllUsesWith(leaderBlock); 646 rewriter.eraseBlock(block); 647 } 648 return success(); 649 } 650 651 /// Identify identical blocks within the given region and merge them, inserting 652 /// new block arguments as necessary. Returns success if any blocks were merged, 653 /// failure otherwise. 654 static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter, 655 Region ®ion) { 656 if (region.empty() || llvm::hasSingleElement(region)) 657 return failure(); 658 659 // Identify sets of blocks, other than the entry block, that branch to the 660 // same successors. We will use these groups to create clusters of equivalent 661 // blocks. 662 DenseMap<SuccessorRange, SmallVector<Block *, 1>> matchingSuccessors; 663 for (Block &block : llvm::drop_begin(region, 1)) 664 matchingSuccessors[block.getSuccessors()].push_back(&block); 665 666 bool mergedAnyBlocks = false; 667 for (ArrayRef<Block *> blocks : llvm::make_second_range(matchingSuccessors)) { 668 if (blocks.size() == 1) 669 continue; 670 671 SmallVector<BlockMergeCluster, 1> clusters; 672 for (Block *block : blocks) { 673 BlockEquivalenceData data(block); 674 675 // Don't allow merging if this block has any regions. 676 // TODO: Add support for regions if necessary. 677 bool hasNonEmptyRegion = llvm::any_of(*block, [](Operation &op) { 678 return llvm::any_of(op.getRegions(), 679 [](Region ®ion) { return !region.empty(); }); 680 }); 681 if (hasNonEmptyRegion) 682 continue; 683 684 // Try to add this block to an existing cluster. 685 bool addedToCluster = false; 686 for (auto &cluster : clusters) 687 if ((addedToCluster = succeeded(cluster.addToCluster(data)))) 688 break; 689 if (!addedToCluster) 690 clusters.emplace_back(std::move(data)); 691 } 692 for (auto &cluster : clusters) 693 mergedAnyBlocks |= succeeded(cluster.merge(rewriter)); 694 } 695 696 return success(mergedAnyBlocks); 697 } 698 699 /// Identify identical blocks within the given regions and merge them, inserting 700 /// new block arguments as necessary. 701 static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter, 702 MutableArrayRef<Region> regions) { 703 llvm::SmallSetVector<Region *, 1> worklist; 704 for (auto ®ion : regions) 705 worklist.insert(®ion); 706 bool anyChanged = false; 707 while (!worklist.empty()) { 708 Region *region = worklist.pop_back_val(); 709 if (succeeded(mergeIdenticalBlocks(rewriter, *region))) { 710 worklist.insert(region); 711 anyChanged = true; 712 } 713 714 // Add any nested regions to the worklist. 715 for (Block &block : *region) 716 for (auto &op : block) 717 for (auto &nestedRegion : op.getRegions()) 718 worklist.insert(&nestedRegion); 719 } 720 721 return success(anyChanged); 722 } 723 724 //===----------------------------------------------------------------------===// 725 // Region Simplification 726 //===----------------------------------------------------------------------===// 727 728 /// Run a set of structural simplifications over the given regions. This 729 /// includes transformations like unreachable block elimination, dead argument 730 /// elimination, as well as some other DCE. This function returns success if any 731 /// of the regions were simplified, failure otherwise. 732 LogicalResult mlir::simplifyRegions(RewriterBase &rewriter, 733 MutableArrayRef<Region> regions) { 734 bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions)); 735 bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions)); 736 bool mergedIdenticalBlocks = 737 succeeded(mergeIdenticalBlocks(rewriter, regions)); 738 return success(eliminatedBlocks || eliminatedOpsOrArgs || 739 mergedIdenticalBlocks); 740 } 741