//===- RegionUtils.cpp - Region-related transformation utilities ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/RegionUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/RegionGraphTraits.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/TopologicalSortUtils.h"

#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallSet.h"

#include <deque>

using namespace mlir;

void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement,
                                      Region &region) {
  for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
    if (region.isAncestor(use.getOwner()->getParentRegion()))
      use.set(replacement);
  }
}

void mlir::visitUsedValuesDefinedAbove(
    Region &region, Region &limit, function_ref<void(OpOperand *)> callback) {
  assert(limit.isAncestor(&region) &&
         "expected isolation limit to be an ancestor of the given region");

  // Collect proper ancestors of `limit` upfront to avoid traversing the region
  // tree for every value.
  SmallPtrSet<Region *, 4> properAncestors;
  for (auto *reg = limit.getParentRegion(); reg != nullptr;
       reg = reg->getParentRegion()) {
    properAncestors.insert(reg);
  }

  region.walk([callback, &properAncestors](Operation *op) {
    for (OpOperand &operand : op->getOpOperands())
      // Callback on values defined in a proper ancestor of region.
      if (properAncestors.count(operand.get().getParentRegion()))
        callback(&operand);
  });
}

void mlir::visitUsedValuesDefinedAbove(
    MutableArrayRef<Region> regions, function_ref<void(OpOperand *)> callback) {
  for (Region &region : regions)
    visitUsedValuesDefinedAbove(region, region, callback);
}

void mlir::getUsedValuesDefinedAbove(Region &region, Region &limit,
                                     SetVector<Value> &values) {
  visitUsedValuesDefinedAbove(region, limit, [&](OpOperand *operand) {
    values.insert(operand->get());
  });
}

void mlir::getUsedValuesDefinedAbove(MutableArrayRef<Region> regions,
                                     SetVector<Value> &values) {
  for (Region &region : regions)
    getUsedValuesDefinedAbove(region, region, values);
}

//===----------------------------------------------------------------------===//
// Make a region isolated from above.
//===----------------------------------------------------------------------===//

SmallVector<Value> mlir::makeRegionIsolatedFromAbove(
    RewriterBase &rewriter, Region &region,
    llvm::function_ref<bool(Operation *)> cloneOperationIntoRegion) {

  // Get the initial list of values used within the region but defined above.
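  // For example (illustrative, using hypothetical `test.*` ops): in
  //   %0 = "test.def"() : () -> i32
  //   "test.wrapper"() ({
  //     "test.use"(%0) : (i32) -> ()
  //   }) : () -> ()
  // the value %0 is used inside the wrapped region but defined above it, so it
  // starts out as a captured value.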
  llvm::SetVector<Value> initialCapturedValues;
  mlir::getUsedValuesDefinedAbove(region, initialCapturedValues);

  std::deque<Value> worklist(initialCapturedValues.begin(),
                             initialCapturedValues.end());
  llvm::DenseSet<Value> visited;
  llvm::DenseSet<Operation *> visitedOps;

  llvm::SetVector<Value> finalCapturedValues;
  SmallVector<Operation *> clonedOperations;
  while (!worklist.empty()) {
    Value currValue = worklist.front();
    worklist.pop_front();
    if (visited.count(currValue))
      continue;
    visited.insert(currValue);

    Operation *definingOp = currValue.getDefiningOp();
    if (!definingOp || visitedOps.count(definingOp)) {
      finalCapturedValues.insert(currValue);
      continue;
    }
    visitedOps.insert(definingOp);

    if (!cloneOperationIntoRegion(definingOp)) {
      // The defining operation isn't cloned, so add the current value to the
      // final captured values list.
      finalCapturedValues.insert(currValue);
      continue;
    }

    // Add all operands of the operation to the worklist and mark the op as to
    // be cloned.
    for (Value operand : definingOp->getOperands()) {
      if (visited.count(operand))
        continue;
      worklist.push_back(operand);
    }
    clonedOperations.push_back(definingOp);
  }

  // The operations to be cloned need to be ordered in topological order
  // so that they can be cloned into the region without violating use-def
  // chains.
  mlir::computeTopologicalSorting(clonedOperations);

  OpBuilder::InsertionGuard g(rewriter);
  // Collect the argument types of the existing entry block.
  Block *entryBlock = &region.front();
  SmallVector<Type> newArgTypes =
      llvm::to_vector(entryBlock->getArgumentTypes());
  SmallVector<Location> newArgLocs = llvm::to_vector(llvm::map_range(
      entryBlock->getArguments(), [](BlockArgument b) { return b.getLoc(); }));

  // Append the types of the captured values.
  for (auto value : finalCapturedValues) {
    newArgTypes.push_back(value.getType());
    newArgLocs.push_back(value.getLoc());
  }

  // Create a new entry block.
  Block *newEntryBlock =
      rewriter.createBlock(&region, region.begin(), newArgTypes, newArgLocs);
  auto newEntryBlockArgs = newEntryBlock->getArguments();

  // Create a mapping between the captured values and the new arguments added.
  IRMapping map;
  auto replaceIfFn = [&](OpOperand &use) {
    return use.getOwner()->getBlock()->getParent() == &region;
  };
  for (auto [arg, capturedVal] :
       llvm::zip(newEntryBlockArgs.take_back(finalCapturedValues.size()),
                 finalCapturedValues)) {
    map.map(capturedVal, arg);
    rewriter.replaceUsesWithIf(capturedVal, arg, replaceIfFn);
  }
  rewriter.setInsertionPointToStart(newEntryBlock);
  for (auto clonedOp : clonedOperations) {
    Operation *newOp = rewriter.clone(*clonedOp, map);
    rewriter.replaceOpWithIf(clonedOp, newOp->getResults(), replaceIfFn);
  }
  rewriter.mergeBlocks(
      entryBlock, newEntryBlock,
      newEntryBlock->getArguments().take_front(entryBlock->getNumArguments()));
  return llvm::to_vector(finalCapturedValues);
}

//===----------------------------------------------------------------------===//
// Unreachable Block Elimination
//===----------------------------------------------------------------------===//

/// Erase the unreachable blocks within the provided regions. Returns success
/// if any blocks were erased, failure otherwise.
// TODO: We could likely merge this with the DCE algorithm below.
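// For example (illustrative): in a multi-block region
//   ^bb0: cf.br ^bb1
//   ^bb1: return
//   ^bb2: cf.br ^bb1
// ^bb2 has no path from the entry block, so it is unreachable and erased.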
LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter,
                                           MutableArrayRef<Region> regions) {
  // Set of blocks found to be reachable within a given region.
  llvm::df_iterator_default_set<Block *, 16> reachable;
  // Whether any blocks were found to be dead.
  bool erasedDeadBlocks = false;

  SmallVector<Region *, 1> worklist;
  worklist.reserve(regions.size());
  for (Region &region : regions)
    worklist.push_back(&region);
  while (!worklist.empty()) {
    Region *region = worklist.pop_back_val();
    if (region->empty())
      continue;

    // If this is a single block region, just collect the nested regions.
    if (std::next(region->begin()) == region->end()) {
      for (Operation &op : region->front())
        for (Region &region : op.getRegions())
          worklist.push_back(&region);
      continue;
    }

    // Mark all reachable blocks.
    reachable.clear();
    for (Block *block : depth_first_ext(&region->front(), reachable))
      (void)block /* Mark all reachable blocks */;

    // Collect all of the dead blocks and push the live regions onto the
    // worklist.
    for (Block &block : llvm::make_early_inc_range(*region)) {
      if (!reachable.count(&block)) {
        block.dropAllDefinedValueUses();
        rewriter.eraseBlock(&block);
        erasedDeadBlocks = true;
        continue;
      }

      // Walk any regions within this block.
      for (Operation &op : block)
        for (Region &region : op.getRegions())
          worklist.push_back(&region);
    }
  }

  return success(erasedDeadBlocks);
}

//===----------------------------------------------------------------------===//
// Dead Code Elimination
//===----------------------------------------------------------------------===//

namespace {
/// Data structure used to track which values have already been proved live.
///
/// Because an Operation can have multiple results, this data structure tracks
/// liveness for both Values and Operations to avoid having to look through
/// all Operation results when analyzing a use.
///
/// This data structure essentially tracks the dataflow lattice.
/// The set of values/ops proved live increases monotonically to a fixed-point.
class LiveMap {
public:
  /// Value methods.
  bool wasProvenLive(Value value) {
    // TODO: For results that are removable, e.g. for region based control
    // flow, we could allow for these values to be tracked independently.
    if (OpResult result = dyn_cast<OpResult>(value))
      return wasProvenLive(result.getOwner());
    return wasProvenLive(cast<BlockArgument>(value));
  }
  bool wasProvenLive(BlockArgument arg) { return liveValues.count(arg); }
  void setProvedLive(Value value) {
    // TODO: For results that are removable, e.g. for region based control
    // flow, we could allow for these values to be tracked independently.
    if (OpResult result = dyn_cast<OpResult>(value))
      return setProvedLive(result.getOwner());
    setProvedLive(cast<BlockArgument>(value));
  }
  void setProvedLive(BlockArgument arg) {
    changed |= liveValues.insert(arg).second;
  }

  /// Operation methods.
  bool wasProvenLive(Operation *op) { return liveOps.count(op); }
  void setProvedLive(Operation *op) { changed |= liveOps.insert(op).second; }

  /// Methods for tracking if we have reached a fixed-point.
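  /// `resetChanged` is called before each propagation pass; `hasChanged`
  /// reports whether that pass proved any new value or op live, i.e. whether
  /// another iteration is needed to reach the fixed-point.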
  void resetChanged() { changed = false; }
  bool hasChanged() { return changed; }

private:
  bool changed = false;
  DenseSet<Value> liveValues;
  DenseSet<Operation *> liveOps;
};
} // namespace

static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap) {
  Operation *owner = use.getOwner();
  unsigned operandIndex = use.getOperandNumber();
  // This pass generally treats all uses of an op as live if the op itself is
  // considered live. However, for successor operands to terminators we need a
  // finer-grained notion where we deduce liveness for operands individually.
  // The reason for this is easiest to think about in terms of a classical phi
  // node based SSA IR, where each successor operand is really an operand to a
  // *separate* phi node, rather than all operands to the branch itself as with
  // the block argument representation that MLIR uses.
  //
  // And similarly, because each successor operand is really an operand to a
  // phi node, rather than to the terminator op itself, a terminator op can't
  // e.g. "print" the value of a successor operand.
  if (owner->hasTrait<OpTrait::IsTerminator>()) {
    if (BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(owner))
      if (auto arg = branchInterface.getSuccessorBlockArgument(operandIndex))
        return !liveMap.wasProvenLive(*arg);
    return false;
  }
  return false;
}

static void processValue(Value value, LiveMap &liveMap) {
  bool provedLive = llvm::any_of(value.getUses(), [&](OpOperand &use) {
    if (isUseSpeciallyKnownDead(use, liveMap))
      return false;
    return liveMap.wasProvenLive(use.getOwner());
  });
  if (provedLive)
    liveMap.setProvedLive(value);
}

static void propagateLiveness(Region &region, LiveMap &liveMap);

static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
  // Terminators are always live.
  liveMap.setProvedLive(op);

  // Check to see if we can reason about the successor operands and mutate
  // them.
  BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(op);
  if (!branchInterface) {
    for (Block *successor : op->getSuccessors())
      for (BlockArgument arg : successor->getArguments())
        liveMap.setProvedLive(arg);
    return;
  }

  // If we can't reason about the operand to a successor, conservatively mark
  // it as live.
  for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
    SuccessorOperands successorOperands =
        branchInterface.getSuccessorOperands(i);
    for (unsigned opI = 0, opE = successorOperands.getProducedOperandCount();
         opI != opE; ++opI)
      liveMap.setProvedLive(op->getSuccessor(i)->getArgument(opI));
  }
}

static void propagateLiveness(Operation *op, LiveMap &liveMap) {
  // Recurse on any regions the op has.
  for (Region &region : op->getRegions())
    propagateLiveness(region, liveMap);

  // Process terminator operations.
  if (op->hasTrait<OpTrait::IsTerminator>())
    return propagateTerminatorLiveness(op, liveMap);

  // Don't reprocess live operations.
  if (liveMap.wasProvenLive(op))
    return;

  // Process the op itself.
  if (!wouldOpBeTriviallyDead(op))
    return liveMap.setProvedLive(op);

  // If the op isn't intrinsically alive, check its results.
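  // A result is proved live by processValue if it has at least one use whose
  // owner was already proved live (ignoring specially-known-dead successor
  // operands), which in turn marks this op live on a later iteration.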
  for (Value value : op->getResults())
    processValue(value, liveMap);
}

static void propagateLiveness(Region &region, LiveMap &liveMap) {
  if (region.empty())
    return;

  for (Block *block : llvm::post_order(&region.front())) {
    // We process block arguments after the ops in the block, to promote
    // faster convergence to a fixed point (we try to visit uses before defs).
    for (Operation &op : llvm::reverse(block->getOperations()))
      propagateLiveness(&op, liveMap);

    // We currently do not remove entry block arguments, so there is no need to
    // track their liveness.
    // TODO: We could track these and enable removing dead operands/arguments
    // from region control flow operations.
    if (block->isEntryBlock())
      continue;

    for (Value value : block->getArguments()) {
      if (!liveMap.wasProvenLive(value))
        processValue(value, liveMap);
    }
  }
}

static void eraseTerminatorSuccessorOperands(Operation *terminator,
                                             LiveMap &liveMap) {
  BranchOpInterface branchOp = dyn_cast<BranchOpInterface>(terminator);
  if (!branchOp)
    return;

  for (unsigned succI = 0, succE = terminator->getNumSuccessors();
       succI < succE; succI++) {
    // Iterating successors in reverse is not strictly needed, since we
    // aren't erasing any successors. But it is slightly more efficient
    // since it will promote later operands of the terminator being erased
    // first, reducing the quadratic-ness.
    unsigned succ = succE - succI - 1;
    SuccessorOperands succOperands = branchOp.getSuccessorOperands(succ);
    Block *successor = terminator->getSuccessor(succ);

    for (unsigned argI = 0, argE = succOperands.size(); argI < argE; ++argI) {
      // Iterating args in reverse is needed for correctness, to avoid
      // shifting later args when earlier args are erased.
      unsigned arg = argE - argI - 1;
      if (!liveMap.wasProvenLive(successor->getArgument(arg)))
        succOperands.erase(arg);
    }
  }
}

static LogicalResult deleteDeadness(RewriterBase &rewriter,
                                    MutableArrayRef<Region> regions,
                                    LiveMap &liveMap) {
  bool erasedAnything = false;
  for (Region &region : regions) {
    if (region.empty())
      continue;
    bool hasSingleBlock = llvm::hasSingleElement(region);

    // Delete every operation that is not live. Graph regions may have cycles
    // in the use-def graph, so we must explicitly dropAllUses() from each
    // operation as we erase it. Visiting the operations in post-order
    // guarantees that in SSA CFG regions value uses are removed before defs,
    // which makes dropAllUses() a no-op.
    for (Block *block : llvm::post_order(&region.front())) {
      if (!hasSingleBlock)
        eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap);
      for (Operation &childOp :
           llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) {
        if (!liveMap.wasProvenLive(&childOp)) {
          erasedAnything = true;
          childOp.dropAllUses();
          rewriter.eraseOp(&childOp);
        } else {
          erasedAnything |= succeeded(
              deleteDeadness(rewriter, childOp.getRegions(), liveMap));
        }
      }
    }
    // Delete block arguments.
    // The entry block has an unknown contract with its enclosing op, so
    // skip it.
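    // For example (illustrative): if ^bb1(%a: i32, %b: i32) only has live uses
    // of %b, then %a is erased here, after the corresponding successor operand
    // was already dropped from each predecessor branch above.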
    for (Block &block : llvm::drop_begin(region.getBlocks(), 1)) {
      block.eraseArguments(
          [&](BlockArgument arg) { return !liveMap.wasProvenLive(arg); });
    }
  }
  return success(erasedAnything);
}

// This function performs a simple dead code elimination algorithm over the
// given regions.
//
// The overall goal is to prove that Values are dead, which allows deleting ops
// and block arguments.
//
// This uses an optimistic algorithm that assumes everything is dead until
// proved otherwise, allowing it to delete recursively dead cycles.
//
// This is a simple fixed-point dataflow analysis algorithm on a lattice
// {Dead,Alive}. Because liveness flows backward, we generally try to
// iterate everything backward to speed up convergence to the fixed-point. This
// allows for deleting recursively dead cycles of the use-def graph, including
// block arguments.
//
// This function returns success if any operations or arguments were deleted,
// failure otherwise.
LogicalResult mlir::runRegionDCE(RewriterBase &rewriter,
                                 MutableArrayRef<Region> regions) {
  LiveMap liveMap;
  do {
    liveMap.resetChanged();

    for (Region &region : regions)
      propagateLiveness(region, liveMap);
  } while (liveMap.hasChanged());

  return deleteDeadness(rewriter, regions, liveMap);
}

//===----------------------------------------------------------------------===//
// Block Merging
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// BlockEquivalenceData

namespace {
/// This class contains the information for comparing the equivalencies of two
/// blocks. Blocks are considered equivalent if they contain the same
/// operations in the same order. The only allowed divergence is for operands
/// that come from sources outside of the parent block, i.e. the uses of values
/// produced within the block must be equivalent.
/// e.g.,
/// Equivalent:
///   ^bb1(%arg0: i32)
///     return %arg0, %foo : i32, i32
///   ^bb2(%arg1: i32)
///     return %arg1, %bar : i32, i32
/// Not Equivalent:
///   ^bb1(%arg0: i32)
///     return %foo, %arg0 : i32, i32
///   ^bb2(%arg1: i32)
///     return %arg1, %bar : i32, i32
struct BlockEquivalenceData {
  BlockEquivalenceData(Block *block);

  /// Return the order index for the given value that is within the block of
  /// this data.
  unsigned getOrderOf(Value value) const;

  /// The block this data refers to.
  Block *block;
  /// A hash value for this block.
  llvm::hash_code hash;
  /// A map of result producing operations to their relative orders within this
  /// block. The order of an operation is the number of defined values that are
  /// produced within the block before this operation.
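  /// For example, in a block with two arguments, the first result-producing
  /// op has order 2 (the two arguments precede it), and if that op produces
  /// two results, the next result-producing op has order 4.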
  DenseMap<Operation *, unsigned> opOrderIndex;
};
} // namespace

BlockEquivalenceData::BlockEquivalenceData(Block *block)
    : block(block), hash(0) {
  unsigned orderIt = block->getNumArguments();
  for (Operation &op : *block) {
    if (unsigned numResults = op.getNumResults()) {
      opOrderIndex.try_emplace(&op, orderIt);
      orderIt += numResults;
    }
    auto opHash = OperationEquivalence::computeHash(
        &op, OperationEquivalence::ignoreHashValue,
        OperationEquivalence::ignoreHashValue,
        OperationEquivalence::IgnoreLocations);
    hash = llvm::hash_combine(hash, opHash);
  }
}

unsigned BlockEquivalenceData::getOrderOf(Value value) const {
  assert(value.getParentBlock() == block && "expected value of this block");

  // Arguments use the argument number as the order index.
  if (BlockArgument arg = dyn_cast<BlockArgument>(value))
    return arg.getArgNumber();

  // Otherwise, the result order is offset from the parent op's order.
  OpResult result = cast<OpResult>(value);
  auto opOrderIt = opOrderIndex.find(result.getDefiningOp());
  assert(opOrderIt != opOrderIndex.end() && "expected op to have an order");
  return opOrderIt->second + result.getResultNumber();
}

//===----------------------------------------------------------------------===//
// BlockMergeCluster

namespace {
/// This class represents a cluster of blocks to be merged together.
class BlockMergeCluster {
public:
  BlockMergeCluster(BlockEquivalenceData &&leaderData)
      : leaderData(std::move(leaderData)) {}

  /// Attempt to add the given block to this cluster. Returns success if the
  /// block was merged, failure otherwise.
  LogicalResult addToCluster(BlockEquivalenceData &blockData);

  /// Try to merge all of the blocks within this cluster into the leader block.
  LogicalResult merge(RewriterBase &rewriter);

private:
  /// The equivalence data for the leader of the cluster.
  BlockEquivalenceData leaderData;

  /// The set of blocks that can be merged into the leader.
  llvm::SmallSetVector<Block *, 1> blocksToMerge;

  /// A set of operand+index pairs that correspond to operands that need to be
  /// replaced by arguments when the cluster gets merged.
  std::set<std::pair<int, int>> operandsToMerge;
};
} // namespace

LogicalResult
BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
  if (leaderData.hash != blockData.hash)
    return failure();
  Block *leaderBlock = leaderData.block, *mergeBlock = blockData.block;
  if (leaderBlock->getArgumentTypes() != mergeBlock->getArgumentTypes())
    return failure();

  // A set of operands that mismatch between the leader and the new block.
  SmallVector<std::pair<int, int>, 8> mismatchedOperands;
  auto lhsIt = leaderBlock->begin(), lhsE = leaderBlock->end();
  auto rhsIt = blockData.block->begin(), rhsE = blockData.block->end();
  for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) {
    // Check that the operations are equivalent.
    if (!OperationEquivalence::isEquivalentTo(
            &*lhsIt, &*rhsIt, OperationEquivalence::ignoreValueEquivalence,
            /*markEquivalent=*/nullptr,
            OperationEquivalence::Flags::IgnoreLocations))
      return failure();

    // Compare the operands of the two operations. If the operand is within
    // the block, it must refer to the same operation.
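    // E.g. (illustrative) `return %arg0, %foo` and `return %arg1, %foo` can
    // match if %arg0/%arg1 map to the same order index in their respective
    // blocks and %foo is defined outside both blocks.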
    auto lhsOperands = lhsIt->getOperands(), rhsOperands = rhsIt->getOperands();
    for (int operand : llvm::seq<int>(0, lhsIt->getNumOperands())) {
      Value lhsOperand = lhsOperands[operand];
      Value rhsOperand = rhsOperands[operand];
      if (lhsOperand == rhsOperand)
        continue;
      // Check that the types of the operands match.
      if (lhsOperand.getType() != rhsOperand.getType())
        return failure();

      // Check that these uses are both external, or both internal.
      bool lhsIsInBlock = lhsOperand.getParentBlock() == leaderBlock;
      bool rhsIsInBlock = rhsOperand.getParentBlock() == mergeBlock;
      if (lhsIsInBlock != rhsIsInBlock)
        return failure();
      // Let the operands differ if they are defined in a different block.
      // These will become new arguments if the blocks get merged.
      if (!lhsIsInBlock) {

        // Check whether the operands aren't the result of an immediate
        // predecessor's terminator. In that case we are not able to use it as
        // a successor operand when branching to the merged block, as the use
        // would not be dominated by its producing operation.
        auto isValidSuccessorArg = [](Block *block, Value operand) {
          if (operand.getDefiningOp() !=
              operand.getParentBlock()->getTerminator())
            return true;
          return !llvm::is_contained(block->getPredecessors(),
                                     operand.getParentBlock());
        };

        if (!isValidSuccessorArg(leaderBlock, lhsOperand) ||
            !isValidSuccessorArg(mergeBlock, rhsOperand))
          return failure();

        mismatchedOperands.emplace_back(opI, operand);
        continue;
      }

      // Otherwise, these operands must have the same logical order within the
      // parent block.
      if (leaderData.getOrderOf(lhsOperand) !=
          blockData.getOrderOf(rhsOperand))
        return failure();
    }

    // If the lhs or rhs has external uses, the blocks cannot be merged as the
    // merged version of this operation will not be either the lhs or rhs
    // alone (thus semantically incorrect), but some mix depending on which
    // block preceded this.
    // TODO: Allow merging of operations when one block does not dominate the
    // other.
    if (rhsIt->isUsedOutsideOfBlock(mergeBlock) ||
        lhsIt->isUsedOutsideOfBlock(leaderBlock)) {
      return failure();
    }
  }
  // Make sure that the block sizes are equivalent.
  if (lhsIt != lhsE || rhsIt != rhsE)
    return failure();

  // If we get here, the blocks are equivalent and can be merged.
  operandsToMerge.insert(mismatchedOperands.begin(), mismatchedOperands.end());
  blocksToMerge.insert(blockData.block);
  return success();
}

/// Returns true if the predecessor terminators of the given block can have
/// their operands updated.
static bool ableToUpdatePredOperands(Block *block) {
  for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
    if (!isa<BranchOpInterface>((*it)->getTerminator()))
      return false;
  }
  return true;
}

LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
  // Don't consider clusters that don't have blocks to merge.
  if (blocksToMerge.empty())
    return failure();

  Block *leaderBlock = leaderData.block;
  if (!operandsToMerge.empty()) {
    // If the cluster has operands to merge, verify that the predecessor
    // terminators of each of the blocks can have their successor operands
    // updated.
    // TODO: We could try and sub-partition this cluster if only some blocks
    // cause the mismatch.
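    // All predecessors must implement BranchOpInterface, since merging appends
    // the mismatched operand values as new successor operands on each
    // predecessor branch below (newArguments[0] holds the leader's values,
    // newArguments[i + 1] those of blocksToMerge[i]).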
    if (!ableToUpdatePredOperands(leaderBlock) ||
        !llvm::all_of(blocksToMerge, ableToUpdatePredOperands))
      return failure();

    // Collect the iterators for each of the blocks to merge. We will walk all
    // of the iterators at once to avoid operand index invalidation.
    SmallVector<Block::iterator, 2> blockIterators;
    blockIterators.reserve(blocksToMerge.size() + 1);
    blockIterators.push_back(leaderBlock->begin());
    for (Block *mergeBlock : blocksToMerge)
      blockIterators.push_back(mergeBlock->begin());

    // Update each of the predecessor terminators with the new arguments.
    SmallVector<SmallVector<Value, 8>, 2> newArguments(
        1 + blocksToMerge.size(),
        SmallVector<Value, 8>(operandsToMerge.size()));
    unsigned curOpIndex = 0;
    for (const auto &it : llvm::enumerate(operandsToMerge)) {
      unsigned nextOpOffset = it.value().first - curOpIndex;
      curOpIndex = it.value().first;

      // Process the operand for each of the block iterators.
      for (unsigned i = 0, e = blockIterators.size(); i != e; ++i) {
        Block::iterator &blockIter = blockIterators[i];
        std::advance(blockIter, nextOpOffset);
        auto &operand = blockIter->getOpOperand(it.value().second);
        newArguments[i][it.index()] = operand.get();

        // Update the operand and insert an argument if this is the leader.
        if (i == 0) {
          Value operandVal = operand.get();
          operand.set(leaderBlock->addArgument(operandVal.getType(),
                                               operandVal.getLoc()));
        }
      }
    }
    // Update the predecessors for each of the blocks.
    auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
      for (auto predIt = block->pred_begin(), predE = block->pred_end();
           predIt != predE; ++predIt) {
        auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
        unsigned succIndex = predIt.getSuccessorIndex();
        branch.getSuccessorOperands(succIndex).append(
            newArguments[clusterIndex]);
      }
    };
    updatePredecessors(leaderBlock, /*clusterIndex=*/0);
    for (unsigned i = 0, e = blocksToMerge.size(); i != e; ++i)
      updatePredecessors(blocksToMerge[i], /*clusterIndex=*/i + 1);
  }

  // Replace all uses of the merged blocks with the leader and erase them.
  for (Block *block : blocksToMerge) {
    block->replaceAllUsesWith(leaderBlock);
    rewriter.eraseBlock(block);
  }
  return success();
}

/// Identify identical blocks within the given region and merge them, inserting
/// new block arguments as necessary. Returns success if any blocks were
/// merged, failure otherwise.
static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
                                          Region &region) {
  if (region.empty() || llvm::hasSingleElement(region))
    return failure();

  // Identify sets of blocks, other than the entry block, that branch to the
  // same successors. We will use these groups to create clusters of equivalent
  // blocks.
  DenseMap<SuccessorRange, SmallVector<Block *, 1>> matchingSuccessors;
  for (Block &block : llvm::drop_begin(region, 1))
    matchingSuccessors[block.getSuccessors()].push_back(&block);

  bool mergedAnyBlocks = false;
  for (ArrayRef<Block *> blocks :
       llvm::make_second_range(matchingSuccessors)) {
    if (blocks.size() == 1)
      continue;

    SmallVector<BlockMergeCluster, 1> clusters;
    for (Block *block : blocks) {
      BlockEquivalenceData data(block);

      // Don't allow merging if this block has any regions.
      // TODO: Add support for regions if necessary.
      bool hasNonEmptyRegion = llvm::any_of(*block, [](Operation &op) {
        return llvm::any_of(op.getRegions(),
                            [](Region &region) { return !region.empty(); });
      });
      if (hasNonEmptyRegion)
        continue;

      // Try to add this block to an existing cluster.
      bool addedToCluster = false;
      for (auto &cluster : clusters)
        if ((addedToCluster = succeeded(cluster.addToCluster(data))))
          break;
      if (!addedToCluster)
        clusters.emplace_back(std::move(data));
    }
    for (auto &cluster : clusters)
      mergedAnyBlocks |= succeeded(cluster.merge(rewriter));
  }

  return success(mergedAnyBlocks);
}

/// Identify identical blocks within the given regions and merge them,
/// inserting new block arguments as necessary.
static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
                                          MutableArrayRef<Region> regions) {
  llvm::SmallSetVector<Region *, 1> worklist;
  for (auto &region : regions)
    worklist.insert(&region);
  bool anyChanged = false;
  while (!worklist.empty()) {
    Region *region = worklist.pop_back_val();
    if (succeeded(mergeIdenticalBlocks(rewriter, *region))) {
      worklist.insert(region);
      anyChanged = true;
    }

    // Add any nested regions to the worklist.
    for (Block &block : *region)
      for (auto &op : block)
        for (auto &nestedRegion : op.getRegions())
          worklist.insert(&nestedRegion);
  }

  return success(anyChanged);
}

//===----------------------------------------------------------------------===//
// Region Simplification
//===----------------------------------------------------------------------===//

/// Run a set of structural simplifications over the given regions. This
/// includes transformations like unreachable block elimination, dead argument
/// elimination, as well as some other DCE. This function returns success if
/// any of the regions were simplified, failure otherwise.
LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
                                    MutableArrayRef<Region> regions) {
  bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
  bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
  bool mergedIdenticalBlocks =
      succeeded(mergeIdenticalBlocks(rewriter, regions));
  return success(eliminatedBlocks || eliminatedOpsOrArgs ||
                 mergedIdenticalBlocks);
}

SetVector<Block *> mlir::getTopologicallySortedBlocks(Region &region) {
  // Add each block that has not been visited yet, together with all blocks
  // reachable from it, in reverse post-order.
  SetVector<Block *> blocks;
  for (Block &b : region) {
    if (blocks.count(&b) == 0) {
      llvm::ReversePostOrderTraversal<Block *> traversal(&b);
      blocks.insert(traversal.begin(), traversal.end());
    }
  }
  assert(blocks.size() == region.getBlocks().size() &&
         "some blocks are not sorted");

  return blocks;
}
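// For example (illustrative): for a region whose blocks appear textually as
// ^a, ^c, ^b, where ^a branches to ^b and ^b branches to ^c, the returned
// order is the reverse post-order [^a, ^b, ^c], regardless of textual order.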