//===- RegionUtils.cpp - Region-related transformation utilities ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/RegionUtils.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/RegionGraphTraits.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/PostOrderIterator.h"

#include <deque>

using namespace mlir;

void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement,
                                      Region &region) {
  for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
    if (region.isAncestor(use.getOwner()->getParentRegion()))
      use.set(replacement);
  }
}

void mlir::visitUsedValuesDefinedAbove(
    Region &region, Region &limit, function_ref<void(OpOperand *)> callback) {
  assert(limit.isAncestor(&region) &&
         "expected isolation limit to be an ancestor of the given region");

  // Collect proper ancestors of `limit` upfront to avoid traversing the region
  // tree for every value.
  SmallPtrSet<Region *, 4> properAncestors;
  for (auto *reg = limit.getParentRegion(); reg != nullptr;
       reg = reg->getParentRegion()) {
    properAncestors.insert(reg);
  }

  region.walk([callback, &properAncestors](Operation *op) {
    for (OpOperand &operand : op->getOpOperands())
      // Callback on values defined in a proper ancestor of the region.
      if (properAncestors.count(operand.get().getParentRegion()))
        callback(&operand);
  });
}

void mlir::visitUsedValuesDefinedAbove(
    MutableArrayRef<Region> regions, function_ref<void(OpOperand *)> callback) {
  for (Region &region : regions)
    visitUsedValuesDefinedAbove(region, region, callback);
}

void mlir::getUsedValuesDefinedAbove(Region &region, Region &limit,
                                     SetVector<Value> &values) {
  visitUsedValuesDefinedAbove(region, limit, [&](OpOperand *operand) {
    values.insert(operand->get());
  });
}

void mlir::getUsedValuesDefinedAbove(MutableArrayRef<Region> regions,
                                     SetVector<Value> &values) {
  for (Region &region : regions)
    getUsedValuesDefinedAbove(region, region, values);
}

//===----------------------------------------------------------------------===//
// Make region isolated from above.
//===----------------------------------------------------------------------===//
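
// Example (illustrative only; "test.def", "test.op", and "test.use" are
// placeholder operations): given a region of "test.op" that captures %a
// defined above,
//
//   %a = "test.def"() : () -> i32
//   "test.op"() ({
//     "test.use"(%a) : (i32) -> ()
//   }) : () -> ()
//
// the region is rewritten so that %a flows in through a new entry block
// argument, with uses inside the region remapped to it:
//
//   "test.op"() ({
//   ^bb0(%arg0: i32):
//     "test.use"(%arg0) : (i32) -> ()
//   }) : () -> ()
//
// The captured values are returned to the caller, which is responsible for
// forwarding them to the region (e.g. as operands of the enclosing op).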

SmallVector<Value> mlir::makeRegionIsolatedFromAbove(
    RewriterBase &rewriter, Region &region,
    llvm::function_ref<bool(Operation *)> cloneOperationIntoRegion) {

  // Get the initial list of values used within the region but defined above.
  llvm::SetVector<Value> initialCapturedValues;
  mlir::getUsedValuesDefinedAbove(region, initialCapturedValues);

  std::deque<Value> worklist(initialCapturedValues.begin(),
                             initialCapturedValues.end());
  llvm::DenseSet<Value> visited;
  llvm::DenseSet<Operation *> visitedOps;

  llvm::SetVector<Value> finalCapturedValues;
  SmallVector<Operation *> clonedOperations;
  while (!worklist.empty()) {
    Value currValue = worklist.front();
    worklist.pop_front();
    if (visited.count(currValue))
      continue;
    visited.insert(currValue);

    Operation *definingOp = currValue.getDefiningOp();
    if (!definingOp || visitedOps.count(definingOp)) {
      finalCapturedValues.insert(currValue);
      continue;
    }
    visitedOps.insert(definingOp);

    if (!cloneOperationIntoRegion(definingOp)) {
      // The defining operation isn't cloned, so add the current value to the
      // final captured values list.
      finalCapturedValues.insert(currValue);
      continue;
    }

    // Add all operands of the operation to the worklist and mark the op as to
    // be cloned.
    for (Value operand : definingOp->getOperands()) {
      if (visited.count(operand))
        continue;
      worklist.push_back(operand);
    }
    clonedOperations.push_back(definingOp);
  }

  // The operations to be cloned need to be ordered in topological order
  // so that they can be cloned into the region without violating use-def
  // chains.
  mlir::computeTopologicalSorting(clonedOperations);

  OpBuilder::InsertionGuard g(rewriter);
  // Collect the types and locations of the existing entry block arguments.
  Block *entryBlock = &region.front();
  SmallVector<Type> newArgTypes =
      llvm::to_vector(entryBlock->getArgumentTypes());
  SmallVector<Location> newArgLocs = llvm::to_vector(llvm::map_range(
      entryBlock->getArguments(), [](BlockArgument b) { return b.getLoc(); }));

  // Append the types of the captured values.
  for (auto value : finalCapturedValues) {
    newArgTypes.push_back(value.getType());
    newArgLocs.push_back(value.getLoc());
  }

  // Create a new entry block.
  Block *newEntryBlock =
      rewriter.createBlock(&region, region.begin(), newArgTypes, newArgLocs);
  auto newEntryBlockArgs = newEntryBlock->getArguments();

  // Create a mapping between the captured values and the new arguments added.
  IRMapping map;
  auto replaceIfFn = [&](OpOperand &use) {
    return use.getOwner()->getBlock()->getParent() == &region;
  };
  for (auto [arg, capturedVal] :
       llvm::zip(newEntryBlockArgs.take_back(finalCapturedValues.size()),
                 finalCapturedValues)) {
    map.map(capturedVal, arg);
    rewriter.replaceUsesWithIf(capturedVal, arg, replaceIfFn);
  }
  rewriter.setInsertionPointToStart(newEntryBlock);
  for (auto *clonedOp : clonedOperations) {
    Operation *newOp = rewriter.clone(*clonedOp, map);
    rewriter.replaceOpUsesWithIf(clonedOp, newOp->getResults(), replaceIfFn);
  }
  rewriter.mergeBlocks(
      entryBlock, newEntryBlock,
      newEntryBlock->getArguments().take_front(entryBlock->getNumArguments()));
  return llvm::to_vector(finalCapturedValues);
}

//===----------------------------------------------------------------------===//
// Unreachable Block Elimination
//===----------------------------------------------------------------------===//
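
// Example (illustrative only): ^bb2 below has no predecessors and is not the
// entry block, so it is unreachable and gets erased:
//
//   func.func @f() {
//     cf.br ^bb1
//   ^bb1:
//     return
//   ^bb2:  // no predecessors
//     return
//   }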

/// Erase the unreachable blocks within the provided regions. Returns success
/// if any blocks were erased, failure otherwise.
// TODO: We could likely merge this with the DCE algorithm below.
LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter,
                                           MutableArrayRef<Region> regions) {
  // Set of blocks found to be reachable within a given region.
  llvm::df_iterator_default_set<Block *, 16> reachable;
  // If any blocks were found to be dead.
  bool erasedDeadBlocks = false;

  SmallVector<Region *, 1> worklist;
  worklist.reserve(regions.size());
  for (Region &region : regions)
    worklist.push_back(&region);
  while (!worklist.empty()) {
    Region *region = worklist.pop_back_val();
    if (region->empty())
      continue;

    // If this is a single block region, just collect the nested regions.
    if (std::next(region->begin()) == region->end()) {
      for (Operation &op : region->front())
        for (Region &region : op.getRegions())
          worklist.push_back(&region);
      continue;
    }

    // Mark all reachable blocks.
    reachable.clear();
    for (Block *block : depth_first_ext(&region->front(), reachable))
      (void)block /* Mark all reachable blocks */;

    // Collect all of the dead blocks and push the live regions onto the
    // worklist.
    for (Block &block : llvm::make_early_inc_range(*region)) {
      if (!reachable.count(&block)) {
        block.dropAllDefinedValueUses();
        rewriter.eraseBlock(&block);
        erasedDeadBlocks = true;
        continue;
      }

      // Walk any regions within this block.
      for (Operation &op : block)
        for (Region &region : op.getRegions())
          worklist.push_back(&region);
    }
  }

  return success(erasedDeadBlocks);
}

//===----------------------------------------------------------------------===//
// Dead Code Elimination
//===----------------------------------------------------------------------===//

namespace {
/// Data structure used to track which values have already been proved live.
///
/// Because operations can have multiple results, this data structure tracks
/// liveness for both values and operations to avoid having to look through
/// all of an operation's results when analyzing a use.
///
/// This data structure essentially tracks the dataflow lattice.
/// The set of values/ops proved live increases monotonically to a fixed-point.
class LiveMap {
public:
  /// Value methods.
  bool wasProvenLive(Value value) {
    // TODO: For results that are removable, e.g. for region based control
    // flow, we could allow for these values to be tracked independently.
    if (OpResult result = dyn_cast<OpResult>(value))
      return wasProvenLive(result.getOwner());
    return wasProvenLive(cast<BlockArgument>(value));
  }
  bool wasProvenLive(BlockArgument arg) { return liveValues.count(arg); }
  void setProvedLive(Value value) {
    // TODO: For results that are removable, e.g. for region based control
    // flow, we could allow for these values to be tracked independently.
    if (OpResult result = dyn_cast<OpResult>(value))
      return setProvedLive(result.getOwner());
    setProvedLive(cast<BlockArgument>(value));
  }
  void setProvedLive(BlockArgument arg) {
    changed |= liveValues.insert(arg).second;
  }

  /// Operation methods.
  bool wasProvenLive(Operation *op) { return liveOps.count(op); }
  void setProvedLive(Operation *op) { changed |= liveOps.insert(op).second; }

  /// Methods for tracking if we have reached a fixed-point.
  void resetChanged() { changed = false; }
  bool hasChanged() { return changed; }

private:
  bool changed = false;
  DenseSet<Value> liveValues;
  DenseSet<Operation *> liveOps;
};
} // namespace
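
// Example (illustrative only; %cond and %v are placeholders): in
//
//   cf.cond_br %cond, ^bb1(%v : i32), ^bb2
//
// the use of %v is really an input to ^bb1's block argument (a phi), not to
// the cf.cond_br itself. If that block argument was not proven live, this use
// does not keep %v alive even though the branch op is live.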
static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap) {
  Operation *owner = use.getOwner();
  unsigned operandIndex = use.getOperandNumber();
  // This pass generally treats all uses of an op as live if the op itself is
  // considered live. However, for successor operands to terminators we need a
  // finer-grained notion where we deduce liveness for operands individually.
  // The reason for this is easiest to think about in terms of a classical phi
  // node based SSA IR, where each successor operand is really an operand to a
  // *separate* phi node, rather than all operands to the branch itself as with
  // the block argument representation that MLIR uses.
  //
  // And similarly, because each successor operand is really an operand to a
  // phi node, rather than to the terminator op itself, a terminator op can't
  // e.g. "print" the value of a successor operand.
  if (owner->hasTrait<OpTrait::IsTerminator>()) {
    if (BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(owner))
      if (auto arg = branchInterface.getSuccessorBlockArgument(operandIndex))
        return !liveMap.wasProvenLive(*arg);
    return false;
  }
  return false;
}

static void processValue(Value value, LiveMap &liveMap) {
  bool provedLive = llvm::any_of(value.getUses(), [&](OpOperand &use) {
    if (isUseSpeciallyKnownDead(use, liveMap))
      return false;
    return liveMap.wasProvenLive(use.getOwner());
  });
  if (provedLive)
    liveMap.setProvedLive(value);
}

static void propagateLiveness(Region &region, LiveMap &liveMap);

static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
  // Terminators are always live.
  liveMap.setProvedLive(op);

  // Check to see if we can reason about the successor operands and mutate
  // them.
  BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(op);
  if (!branchInterface) {
    for (Block *successor : op->getSuccessors())
      for (BlockArgument arg : successor->getArguments())
        liveMap.setProvedLive(arg);
    return;
  }

  // If we can't reason about the operand to a successor, conservatively mark
  // it as live.
  for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
    SuccessorOperands successorOperands =
        branchInterface.getSuccessorOperands(i);
    for (unsigned opI = 0, opE = successorOperands.getProducedOperandCount();
         opI != opE; ++opI)
      liveMap.setProvedLive(op->getSuccessor(i)->getArgument(opI));
  }
}

static void propagateLiveness(Operation *op, LiveMap &liveMap) {
  // Recurse on any regions the op has.
  for (Region &region : op->getRegions())
    propagateLiveness(region, liveMap);

  // Process terminator operations.
  if (op->hasTrait<OpTrait::IsTerminator>())
    return propagateTerminatorLiveness(op, liveMap);

  // Don't reprocess live operations.
  if (liveMap.wasProvenLive(op))
    return;

  // Process the op itself.
  if (!wouldOpBeTriviallyDead(op))
    return liveMap.setProvedLive(op);

  // If the op isn't intrinsically alive, check its results.
  for (Value value : op->getResults())
    processValue(value, liveMap);
}

static void propagateLiveness(Region &region, LiveMap &liveMap) {
  if (region.empty())
    return;

  for (Block *block : llvm::post_order(&region.front())) {
    // We process block arguments after the ops in the block, to promote
    // faster convergence to a fixed point (we try to visit uses before defs).
    for (Operation &op : llvm::reverse(block->getOperations()))
      propagateLiveness(&op, liveMap);

    // We currently do not remove entry block arguments, so there is no need
    // to track their liveness.
    // TODO: We could track these and enable removing dead operands/arguments
    // from region control flow operations.
    if (block->isEntryBlock())
      continue;

    for (Value value : block->getArguments()) {
      if (!liveMap.wasProvenLive(value))
        processValue(value, liveMap);
    }
  }
}
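
// Worked example (illustrative only): if a successor has arguments at indices
// 0, 1, and 2, and arguments 0 and 2 are dead, erasing index 0 first would
// shift index 2 down to 1 and invalidate it. Erasing in reverse (2, then 0)
// keeps every not-yet-visited index stable.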
static void eraseTerminatorSuccessorOperands(Operation *terminator,
                                             LiveMap &liveMap) {
  BranchOpInterface branchOp = dyn_cast<BranchOpInterface>(terminator);
  if (!branchOp)
    return;

  for (unsigned succI = 0, succE = terminator->getNumSuccessors();
       succI < succE; succI++) {
    // Iterating successors in reverse is not strictly needed, since we
    // aren't erasing any successors. But it is slightly more efficient
    // since it means later operands of the terminator get erased first,
    // reducing the quadratic behavior.
    unsigned succ = succE - succI - 1;
    SuccessorOperands succOperands = branchOp.getSuccessorOperands(succ);
    Block *successor = terminator->getSuccessor(succ);

    for (unsigned argI = 0, argE = succOperands.size(); argI < argE; ++argI) {
      // Iterating args in reverse is needed for correctness, to avoid
      // shifting later args when earlier args are erased.
      unsigned arg = argE - argI - 1;
      if (!liveMap.wasProvenLive(successor->getArgument(arg)))
        succOperands.erase(arg);
    }
  }
}

static LogicalResult deleteDeadness(RewriterBase &rewriter,
                                    MutableArrayRef<Region> regions,
                                    LiveMap &liveMap) {
  bool erasedAnything = false;
  for (Region &region : regions) {
    if (region.empty())
      continue;
    bool hasSingleBlock = llvm::hasSingleElement(region);

    // Delete every operation that is not live. Graph regions may have cycles
    // in the use-def graph, so we must explicitly dropAllUses() from each
    // operation as we erase it. Visiting the operations in post-order
    // guarantees that in SSA CFG regions value uses are removed before defs,
    // which makes dropAllUses() a no-op.
    for (Block *block : llvm::post_order(&region.front())) {
      if (!hasSingleBlock)
        eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap);
      for (Operation &childOp :
           llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) {
        if (!liveMap.wasProvenLive(&childOp)) {
          erasedAnything = true;
          childOp.dropAllUses();
          rewriter.eraseOp(&childOp);
        } else {
          erasedAnything |= succeeded(
              deleteDeadness(rewriter, childOp.getRegions(), liveMap));
        }
      }
    }
    // Delete block arguments.
    // The entry block has an unknown contract with its enclosing op, so
    // skip it.
    for (Block &block : llvm::drop_begin(region.getBlocks(), 1)) {
      block.eraseArguments(
          [&](BlockArgument arg) { return !liveMap.wasProvenLive(arg); });
    }
  }
  return success(erasedAnything);
}

// This function performs a simple dead code elimination algorithm over the
// given regions.
//
// The overall goal is to prove that Values are dead, which allows deleting ops
// and block arguments.
//
// This uses an optimistic algorithm that assumes everything is dead until
// proved otherwise, allowing it to delete recursively dead cycles.
//
// This is a simple fixed-point dataflow analysis algorithm on a lattice
// {Dead,Alive}. Because liveness flows backward, we generally try to iterate
// everything backward to speed up convergence to the fixed-point. This makes
// it possible to delete recursively dead cycles of the use-def graph,
// including block arguments.
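//
// For example (illustrative IR), the block arguments below only feed each
// other, so the optimistic analysis never proves them live, and both the
// arguments and the successor operands feeding them can be deleted:
//
//   ^bb1(%a: i32):
//     cf.br ^bb2(%a : i32)
//   ^bb2(%b: i32):
//     cf.br ^bb1(%b : i32)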
//
// This function returns success if any operations or arguments were deleted,
// failure otherwise.
LogicalResult mlir::runRegionDCE(RewriterBase &rewriter,
                                 MutableArrayRef<Region> regions) {
  LiveMap liveMap;
  do {
    liveMap.resetChanged();

    for (Region &region : regions)
      propagateLiveness(region, liveMap);
  } while (liveMap.hasChanged());

  return deleteDeadness(rewriter, regions, liveMap);
}

//===----------------------------------------------------------------------===//
// Block Merging
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// BlockEquivalenceData

namespace {
/// This class contains the information for comparing the equivalencies of two
/// blocks. Blocks are considered equivalent if they contain the same
/// operations in the same order. The only allowed divergence is for operands
/// that come from sources outside of the parent block, i.e., the uses of
/// values produced within the block must be equivalent.
/// e.g.,
/// Equivalent:
///  ^bb1(%arg0: i32)
///    return %arg0, %foo : i32, i32
///  ^bb2(%arg1: i32)
///    return %arg1, %bar : i32, i32
/// Not Equivalent:
///  ^bb1(%arg0: i32)
///    return %foo, %arg0 : i32, i32
///  ^bb2(%arg1: i32)
///    return %arg1, %bar : i32, i32
struct BlockEquivalenceData {
  BlockEquivalenceData(Block *block);

  /// Return the order index for the given value that is within the block of
  /// this data.
  unsigned getOrderOf(Value value) const;

  /// The block this data refers to.
  Block *block;
  /// A hash value for this block.
  llvm::hash_code hash;
  /// A map of result producing operations to their relative orders within this
  /// block. The order of an operation is the number of defined values that are
  /// produced within the block before this operation.
  DenseMap<Operation *, unsigned> opOrderIndex;
};
} // namespace
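
// Worked example (illustrative only; "test.a" and "test.b" are placeholders):
// for a block
//
//   ^bb(%arg0: i32, %arg1: i32):
//     %0:2 = "test.a"() : () -> (i32, i32)
//     %1 = "test.b"() : () -> i32
//
// the order indices are %arg0 = 0, %arg1 = 1, %0#0 = 2, %0#1 = 3, and %1 = 4;
// opOrderIndex maps "test.a" to 2 and "test.b" to 4.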
BlockEquivalenceData::BlockEquivalenceData(Block *block)
    : block(block), hash(0) {
  unsigned orderIt = block->getNumArguments();
  for (Operation &op : *block) {
    if (unsigned numResults = op.getNumResults()) {
      opOrderIndex.try_emplace(&op, orderIt);
      orderIt += numResults;
    }
    auto opHash = OperationEquivalence::computeHash(
        &op, OperationEquivalence::ignoreHashValue,
        OperationEquivalence::ignoreHashValue,
        OperationEquivalence::IgnoreLocations);
    hash = llvm::hash_combine(hash, opHash);
  }
}

unsigned BlockEquivalenceData::getOrderOf(Value value) const {
  assert(value.getParentBlock() == block && "expected value of this block");

  // Arguments use the argument number as the order index.
  if (BlockArgument arg = dyn_cast<BlockArgument>(value))
    return arg.getArgNumber();

  // Otherwise, the result order is offset from the parent op's order.
  OpResult result = cast<OpResult>(value);
  auto opOrderIt = opOrderIndex.find(result.getDefiningOp());
  assert(opOrderIt != opOrderIndex.end() && "expected op to have an order");
  return opOrderIt->second + result.getResultNumber();
}

//===----------------------------------------------------------------------===//
// BlockMergeCluster

namespace {
/// This class represents a cluster of blocks to be merged together.
class BlockMergeCluster {
public:
  BlockMergeCluster(BlockEquivalenceData &&leaderData)
      : leaderData(std::move(leaderData)) {}

  /// Attempt to add the given block to this cluster. Returns success if the
  /// block was merged, failure otherwise.
  LogicalResult addToCluster(BlockEquivalenceData &blockData);

  /// Try to merge all of the blocks within this cluster into the leader block.
  LogicalResult merge(RewriterBase &rewriter);

private:
  /// The equivalence data for the leader of the cluster.
  BlockEquivalenceData leaderData;

  /// The set of blocks that can be merged into the leader.
  llvm::SmallSetVector<Block *, 1> blocksToMerge;

  /// A set of (operation index, operand index) pairs identifying operands
  /// that need to be replaced by new block arguments when the cluster gets
  /// merged.
  std::set<std::pair<int, int>> operandsToMerge;
};
} // namespace
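
// Example (illustrative only; "test.return" is a placeholder): these two
// blocks are equivalent except for the external values they use (%foo vs.
// %bar), so the second can join the first's cluster, recording the pair
// (operation index 0, operand index 0) in operandsToMerge:
//
//   ^bb1:
//     "test.return"(%foo) : (i32) -> ()
//   ^bb2:
//     "test.return"(%bar) : (i32) -> ()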
LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
  if (leaderData.hash != blockData.hash)
    return failure();
  Block *leaderBlock = leaderData.block, *mergeBlock = blockData.block;
  if (leaderBlock->getArgumentTypes() != mergeBlock->getArgumentTypes())
    return failure();

  // A set of operands that mismatch between the leader and the new block.
  SmallVector<std::pair<int, int>, 8> mismatchedOperands;
  auto lhsIt = leaderBlock->begin(), lhsE = leaderBlock->end();
  auto rhsIt = blockData.block->begin(), rhsE = blockData.block->end();
  for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) {
    // Check that the operations are equivalent.
    if (!OperationEquivalence::isEquivalentTo(
            &*lhsIt, &*rhsIt, OperationEquivalence::ignoreValueEquivalence,
            /*markEquivalent=*/nullptr,
            OperationEquivalence::Flags::IgnoreLocations))
      return failure();

    // Compare the operands of the two operations. If an operand is defined
    // within the block, it must refer to the same operation.
    auto lhsOperands = lhsIt->getOperands(), rhsOperands = rhsIt->getOperands();
    for (int operand : llvm::seq<int>(0, lhsIt->getNumOperands())) {
      Value lhsOperand = lhsOperands[operand];
      Value rhsOperand = rhsOperands[operand];
      if (lhsOperand == rhsOperand)
        continue;
      // Check that the types of the operands match.
      if (lhsOperand.getType() != rhsOperand.getType())
        return failure();

      // Check that these uses are both external, or both internal.
      bool lhsIsInBlock = lhsOperand.getParentBlock() == leaderBlock;
      bool rhsIsInBlock = rhsOperand.getParentBlock() == mergeBlock;
      if (lhsIsInBlock != rhsIsInBlock)
        return failure();
      // Let the operands differ if they are defined in a different block.
      // These will become new arguments if the blocks get merged.
      if (!lhsIsInBlock) {
        // Check whether the operand is a result of an immediate predecessor's
        // terminator. In that case we cannot use it as a successor operand
        // when branching to the merged block, as that branch would be the
        // very operation producing the value.
        auto isValidSuccessorArg = [](Block *block, Value operand) {
          if (operand.getDefiningOp() !=
              operand.getParentBlock()->getTerminator())
            return true;
          return !llvm::is_contained(block->getPredecessors(),
                                     operand.getParentBlock());
        };

        if (!isValidSuccessorArg(leaderBlock, lhsOperand) ||
            !isValidSuccessorArg(mergeBlock, rhsOperand))
          return failure();

        mismatchedOperands.emplace_back(opI, operand);
        continue;
      }

      // Otherwise, these operands must have the same logical order within the
      // parent block.
      if (leaderData.getOrderOf(lhsOperand) != blockData.getOrderOf(rhsOperand))
        return failure();
    }

    // If the lhs or rhs has external uses, the blocks cannot be merged as the
    // merged version of this operation will not be either the lhs or rhs
    // alone (thus semantically incorrect), but some mix depending on which
    // block preceded this.
    // TODO: Allow merging of operations when one block does not dominate the
    // other.
    if (rhsIt->isUsedOutsideOfBlock(mergeBlock) ||
        lhsIt->isUsedOutsideOfBlock(leaderBlock)) {
      return failure();
    }
  }
  // Make sure that the block sizes are equivalent.
  if (lhsIt != lhsE || rhsIt != rhsE)
    return failure();

  // If we get here, the blocks are equivalent and can be merged.
  operandsToMerge.insert(mismatchedOperands.begin(), mismatchedOperands.end());
  blocksToMerge.insert(blockData.block);
  return success();
}

/// Returns true if the operands of the predecessor terminators of the given
/// block can be updated, i.e., every predecessor terminator implements
/// BranchOpInterface.
static bool ableToUpdatePredOperands(Block *block) {
  for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
    if (!isa<BranchOpInterface>((*it)->getTerminator()))
      return false;
  }
  return true;
}
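
// Continuing the example above (illustrative only): merging ^bb2 into ^bb1
// adds a block argument to the leader, remaps the mismatched operand to it,
// and each predecessor branch forwards the external value its old successor
// used:
//
//   cf.cond_br %c, ^bb1(%foo : i32), ^bb1(%bar : i32)
//   ^bb1(%merged: i32):
//     "test.return"(%merged) : (i32) -> ()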
LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
  // Don't consider clusters that don't have blocks to merge.
  if (blocksToMerge.empty())
    return failure();

  Block *leaderBlock = leaderData.block;
  if (!operandsToMerge.empty()) {
    // If the cluster has operands to merge, verify that the predecessor
    // terminators of each of the blocks can have their successor operands
    // updated.
    // TODO: We could try and sub-partition this cluster if only some blocks
    // cause the mismatch.
    if (!ableToUpdatePredOperands(leaderBlock) ||
        !llvm::all_of(blocksToMerge, ableToUpdatePredOperands))
      return failure();

    // Collect the iterators for each of the blocks to merge. We will walk all
    // of the iterators at once to avoid operand index invalidation.
    SmallVector<Block::iterator, 2> blockIterators;
    blockIterators.reserve(blocksToMerge.size() + 1);
    blockIterators.push_back(leaderBlock->begin());
    for (Block *mergeBlock : blocksToMerge)
      blockIterators.push_back(mergeBlock->begin());

    // Update each of the predecessor terminators with the new arguments.
    SmallVector<SmallVector<Value, 8>, 2> newArguments(
        1 + blocksToMerge.size(),
        SmallVector<Value, 8>(operandsToMerge.size()));
    unsigned curOpIndex = 0;
    for (const auto &it : llvm::enumerate(operandsToMerge)) {
      unsigned nextOpOffset = it.value().first - curOpIndex;
      curOpIndex = it.value().first;

      // Process the operand for each of the block iterators.
      for (unsigned i = 0, e = blockIterators.size(); i != e; ++i) {
        Block::iterator &blockIter = blockIterators[i];
        std::advance(blockIter, nextOpOffset);
        auto &operand = blockIter->getOpOperand(it.value().second);
        newArguments[i][it.index()] = operand.get();

        // Update the operand and insert an argument if this is the leader.
        if (i == 0) {
          Value operandVal = operand.get();
          operand.set(leaderBlock->addArgument(operandVal.getType(),
                                               operandVal.getLoc()));
        }
      }
    }
    // Update the predecessors for each of the blocks.
    auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
      for (auto predIt = block->pred_begin(), predE = block->pred_end();
           predIt != predE; ++predIt) {
        auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
        unsigned succIndex = predIt.getSuccessorIndex();
        branch.getSuccessorOperands(succIndex).append(
            newArguments[clusterIndex]);
      }
    };
    updatePredecessors(leaderBlock, /*clusterIndex=*/0);
    for (unsigned i = 0, e = blocksToMerge.size(); i != e; ++i)
      updatePredecessors(blocksToMerge[i], /*clusterIndex=*/i + 1);
  }

  // Replace all uses of the merged blocks with the leader and erase them.
  for (Block *block : blocksToMerge) {
    block->replaceAllUsesWith(leaderBlock);
    rewriter.eraseBlock(block);
  }
  return success();
}

/// Identify identical blocks within the given region and merge them, inserting
/// new block arguments as necessary. Returns success if any blocks were
/// merged, failure otherwise.
static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
                                          Region &region) {
  if (region.empty() || llvm::hasSingleElement(region))
    return failure();

  // Identify sets of blocks, other than the entry block, that branch to the
  // same successors. We will use these groups to create clusters of equivalent
  // blocks.
  DenseMap<SuccessorRange, SmallVector<Block *, 1>> matchingSuccessors;
  for (Block &block : llvm::drop_begin(region, 1))
    matchingSuccessors[block.getSuccessors()].push_back(&block);

  bool mergedAnyBlocks = false;
  for (ArrayRef<Block *> blocks :
       llvm::make_second_range(matchingSuccessors)) {
    if (blocks.size() == 1)
      continue;

    SmallVector<BlockMergeCluster, 1> clusters;
    for (Block *block : blocks) {
      BlockEquivalenceData data(block);

      // Don't allow merging if this block has any regions.
      // TODO: Add support for regions if necessary.
      bool hasNonEmptyRegion = llvm::any_of(*block, [](Operation &op) {
        return llvm::any_of(op.getRegions(),
                            [](Region &region) { return !region.empty(); });
      });
      if (hasNonEmptyRegion)
        continue;

      // Try to add this block to an existing cluster.
      bool addedToCluster = false;
      for (auto &cluster : clusters)
        if ((addedToCluster = succeeded(cluster.addToCluster(data))))
          break;
      if (!addedToCluster)
        clusters.emplace_back(std::move(data));
    }
    for (auto &cluster : clusters)
      mergedAnyBlocks |= succeeded(cluster.merge(rewriter));
  }

  return success(mergedAnyBlocks);
}

/// Identify identical blocks within the given regions and merge them,
/// inserting new block arguments as necessary.
static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
                                          MutableArrayRef<Region> regions) {
  llvm::SmallSetVector<Region *, 1> worklist;
  for (auto &region : regions)
    worklist.insert(&region);
  bool anyChanged = false;
  while (!worklist.empty()) {
    Region *region = worklist.pop_back_val();
    if (succeeded(mergeIdenticalBlocks(rewriter, *region))) {
      worklist.insert(region);
      anyChanged = true;
    }

    // Add any nested regions to the worklist.
    for (Block &block : *region)
      for (auto &op : block)
        for (auto &nestedRegion : op.getRegions())
          worklist.insert(&nestedRegion);
  }

  return success(anyChanged);
}

//===----------------------------------------------------------------------===//
// Region Simplification
//===----------------------------------------------------------------------===//

/// Run a set of structural simplifications over the given regions. This
/// includes transformations like unreachable block elimination, dead argument
/// elimination, as well as some other DCE. This function returns success if
/// any of the regions were simplified, failure otherwise.
LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
                                    MutableArrayRef<Region> regions,
                                    bool mergeBlocks) {
  bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
  bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
  bool mergedIdenticalBlocks = false;
  if (mergeBlocks)
    mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
  return success(eliminatedBlocks || eliminatedOpsOrArgs ||
                 mergedIdenticalBlocks);
}