1 //===- RegionUtils.cpp - Region-related transformation utilities ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Transforms/RegionUtils.h" 10 #include "mlir/Analysis/TopologicalSortUtils.h" 11 #include "mlir/IR/Block.h" 12 #include "mlir/IR/BuiltinOps.h" 13 #include "mlir/IR/IRMapping.h" 14 #include "mlir/IR/Operation.h" 15 #include "mlir/IR/PatternMatch.h" 16 #include "mlir/IR/RegionGraphTraits.h" 17 #include "mlir/IR/Value.h" 18 #include "mlir/Interfaces/ControlFlowInterfaces.h" 19 #include "mlir/Interfaces/SideEffectInterfaces.h" 20 #include "mlir/Support/LogicalResult.h" 21 22 #include "llvm/ADT/DepthFirstIterator.h" 23 #include "llvm/ADT/PostOrderIterator.h" 24 #include "llvm/ADT/STLExtras.h" 25 #include "llvm/ADT/SmallSet.h" 26 27 #include <deque> 28 #include <iterator> 29 30 using namespace mlir; 31 32 void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement, 33 Region ®ion) { 34 for (auto &use : llvm::make_early_inc_range(orig.getUses())) { 35 if (region.isAncestor(use.getOwner()->getParentRegion())) 36 use.set(replacement); 37 } 38 } 39 40 void mlir::visitUsedValuesDefinedAbove( 41 Region ®ion, Region &limit, function_ref<void(OpOperand *)> callback) { 42 assert(limit.isAncestor(®ion) && 43 "expected isolation limit to be an ancestor of the given region"); 44 45 // Collect proper ancestors of `limit` upfront to avoid traversing the region 46 // tree for every value. 47 SmallPtrSet<Region *, 4> properAncestors; 48 for (auto *reg = limit.getParentRegion(); reg != nullptr; 49 reg = reg->getParentRegion()) { 50 properAncestors.insert(reg); 51 } 52 53 region.walk([callback, &properAncestors](Operation *op) { 54 for (OpOperand &operand : op->getOpOperands()) 55 // Callback on values defined in a proper ancestor of region. 56 if (properAncestors.count(operand.get().getParentRegion())) 57 callback(&operand); 58 }); 59 } 60 61 void mlir::visitUsedValuesDefinedAbove( 62 MutableArrayRef<Region> regions, function_ref<void(OpOperand *)> callback) { 63 for (Region ®ion : regions) 64 visitUsedValuesDefinedAbove(region, region, callback); 65 } 66 67 void mlir::getUsedValuesDefinedAbove(Region ®ion, Region &limit, 68 SetVector<Value> &values) { 69 visitUsedValuesDefinedAbove(region, limit, [&](OpOperand *operand) { 70 values.insert(operand->get()); 71 }); 72 } 73 74 void mlir::getUsedValuesDefinedAbove(MutableArrayRef<Region> regions, 75 SetVector<Value> &values) { 76 for (Region ®ion : regions) 77 getUsedValuesDefinedAbove(region, region, values); 78 } 79 80 //===----------------------------------------------------------------------===// 81 // Make block isolated from above. 82 //===----------------------------------------------------------------------===// 83 84 SmallVector<Value> mlir::makeRegionIsolatedFromAbove( 85 RewriterBase &rewriter, Region ®ion, 86 llvm::function_ref<bool(Operation *)> cloneOperationIntoRegion) { 87 88 // Get initial list of values used within region but defined above. 89 llvm::SetVector<Value> initialCapturedValues; 90 mlir::getUsedValuesDefinedAbove(region, initialCapturedValues); 91 92 std::deque<Value> worklist(initialCapturedValues.begin(), 93 initialCapturedValues.end()); 94 llvm::DenseSet<Value> visited; 95 llvm::DenseSet<Operation *> visitedOps; 96 97 llvm::SetVector<Value> finalCapturedValues; 98 SmallVector<Operation *> clonedOperations; 99 while (!worklist.empty()) { 100 Value currValue = worklist.front(); 101 worklist.pop_front(); 102 if (visited.count(currValue)) 103 continue; 104 visited.insert(currValue); 105 106 Operation *definingOp = currValue.getDefiningOp(); 107 if (!definingOp || visitedOps.count(definingOp)) { 108 finalCapturedValues.insert(currValue); 109 continue; 110 } 111 visitedOps.insert(definingOp); 112 113 if (!cloneOperationIntoRegion(definingOp)) { 114 // Defining operation isnt cloned, so add the current value to final 115 // captured values list. 116 finalCapturedValues.insert(currValue); 117 continue; 118 } 119 120 // Add all operands of the operation to the worklist and mark the op as to 121 // be cloned. 122 for (Value operand : definingOp->getOperands()) { 123 if (visited.count(operand)) 124 continue; 125 worklist.push_back(operand); 126 } 127 clonedOperations.push_back(definingOp); 128 } 129 130 // The operations to be cloned need to be ordered in topological order 131 // so that they can be cloned into the region without violating use-def 132 // chains. 133 mlir::computeTopologicalSorting(clonedOperations); 134 135 OpBuilder::InsertionGuard g(rewriter); 136 // Collect types of existing block 137 Block *entryBlock = ®ion.front(); 138 SmallVector<Type> newArgTypes = 139 llvm::to_vector(entryBlock->getArgumentTypes()); 140 SmallVector<Location> newArgLocs = llvm::to_vector(llvm::map_range( 141 entryBlock->getArguments(), [](BlockArgument b) { return b.getLoc(); })); 142 143 // Append the types of the captured values. 144 for (auto value : finalCapturedValues) { 145 newArgTypes.push_back(value.getType()); 146 newArgLocs.push_back(value.getLoc()); 147 } 148 149 // Create a new entry block. 150 Block *newEntryBlock = 151 rewriter.createBlock(®ion, region.begin(), newArgTypes, newArgLocs); 152 auto newEntryBlockArgs = newEntryBlock->getArguments(); 153 154 // Create a mapping between the captured values and the new arguments added. 155 IRMapping map; 156 auto replaceIfFn = [&](OpOperand &use) { 157 return use.getOwner()->getBlock()->getParent() == ®ion; 158 }; 159 for (auto [arg, capturedVal] : 160 llvm::zip(newEntryBlockArgs.take_back(finalCapturedValues.size()), 161 finalCapturedValues)) { 162 map.map(capturedVal, arg); 163 rewriter.replaceUsesWithIf(capturedVal, arg, replaceIfFn); 164 } 165 rewriter.setInsertionPointToStart(newEntryBlock); 166 for (auto *clonedOp : clonedOperations) { 167 Operation *newOp = rewriter.clone(*clonedOp, map); 168 rewriter.replaceOpUsesWithIf(clonedOp, newOp->getResults(), replaceIfFn); 169 } 170 rewriter.mergeBlocks( 171 entryBlock, newEntryBlock, 172 newEntryBlock->getArguments().take_front(entryBlock->getNumArguments())); 173 return llvm::to_vector(finalCapturedValues); 174 } 175 176 //===----------------------------------------------------------------------===// 177 // Unreachable Block Elimination 178 //===----------------------------------------------------------------------===// 179 180 /// Erase the unreachable blocks within the provided regions. Returns success 181 /// if any blocks were erased, failure otherwise. 182 // TODO: We could likely merge this with the DCE algorithm below. 183 LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter, 184 MutableArrayRef<Region> regions) { 185 // Set of blocks found to be reachable within a given region. 186 llvm::df_iterator_default_set<Block *, 16> reachable; 187 // If any blocks were found to be dead. 188 bool erasedDeadBlocks = false; 189 190 SmallVector<Region *, 1> worklist; 191 worklist.reserve(regions.size()); 192 for (Region ®ion : regions) 193 worklist.push_back(®ion); 194 while (!worklist.empty()) { 195 Region *region = worklist.pop_back_val(); 196 if (region->empty()) 197 continue; 198 199 // If this is a single block region, just collect the nested regions. 200 if (std::next(region->begin()) == region->end()) { 201 for (Operation &op : region->front()) 202 for (Region ®ion : op.getRegions()) 203 worklist.push_back(®ion); 204 continue; 205 } 206 207 // Mark all reachable blocks. 208 reachable.clear(); 209 for (Block *block : depth_first_ext(®ion->front(), reachable)) 210 (void)block /* Mark all reachable blocks */; 211 212 // Collect all of the dead blocks and push the live regions onto the 213 // worklist. 214 for (Block &block : llvm::make_early_inc_range(*region)) { 215 if (!reachable.count(&block)) { 216 block.dropAllDefinedValueUses(); 217 rewriter.eraseBlock(&block); 218 erasedDeadBlocks = true; 219 continue; 220 } 221 222 // Walk any regions within this block. 223 for (Operation &op : block) 224 for (Region ®ion : op.getRegions()) 225 worklist.push_back(®ion); 226 } 227 } 228 229 return success(erasedDeadBlocks); 230 } 231 232 //===----------------------------------------------------------------------===// 233 // Dead Code Elimination 234 //===----------------------------------------------------------------------===// 235 236 namespace { 237 /// Data structure used to track which values have already been proved live. 238 /// 239 /// Because Operation's can have multiple results, this data structure tracks 240 /// liveness for both Value's and Operation's to avoid having to look through 241 /// all Operation results when analyzing a use. 242 /// 243 /// This data structure essentially tracks the dataflow lattice. 244 /// The set of values/ops proved live increases monotonically to a fixed-point. 245 class LiveMap { 246 public: 247 /// Value methods. 248 bool wasProvenLive(Value value) { 249 // TODO: For results that are removable, e.g. for region based control flow, 250 // we could allow for these values to be tracked independently. 251 if (OpResult result = dyn_cast<OpResult>(value)) 252 return wasProvenLive(result.getOwner()); 253 return wasProvenLive(cast<BlockArgument>(value)); 254 } 255 bool wasProvenLive(BlockArgument arg) { return liveValues.count(arg); } 256 void setProvedLive(Value value) { 257 // TODO: For results that are removable, e.g. for region based control flow, 258 // we could allow for these values to be tracked independently. 259 if (OpResult result = dyn_cast<OpResult>(value)) 260 return setProvedLive(result.getOwner()); 261 setProvedLive(cast<BlockArgument>(value)); 262 } 263 void setProvedLive(BlockArgument arg) { 264 changed |= liveValues.insert(arg).second; 265 } 266 267 /// Operation methods. 268 bool wasProvenLive(Operation *op) { return liveOps.count(op); } 269 void setProvedLive(Operation *op) { changed |= liveOps.insert(op).second; } 270 271 /// Methods for tracking if we have reached a fixed-point. 272 void resetChanged() { changed = false; } 273 bool hasChanged() { return changed; } 274 275 private: 276 bool changed = false; 277 DenseSet<Value> liveValues; 278 DenseSet<Operation *> liveOps; 279 }; 280 } // namespace 281 282 static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap) { 283 Operation *owner = use.getOwner(); 284 unsigned operandIndex = use.getOperandNumber(); 285 // This pass generally treats all uses of an op as live if the op itself is 286 // considered live. However, for successor operands to terminators we need a 287 // finer-grained notion where we deduce liveness for operands individually. 288 // The reason for this is easiest to think about in terms of a classical phi 289 // node based SSA IR, where each successor operand is really an operand to a 290 // *separate* phi node, rather than all operands to the branch itself as with 291 // the block argument representation that MLIR uses. 292 // 293 // And similarly, because each successor operand is really an operand to a phi 294 // node, rather than to the terminator op itself, a terminator op can't e.g. 295 // "print" the value of a successor operand. 296 if (owner->hasTrait<OpTrait::IsTerminator>()) { 297 if (BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(owner)) 298 if (auto arg = branchInterface.getSuccessorBlockArgument(operandIndex)) 299 return !liveMap.wasProvenLive(*arg); 300 return false; 301 } 302 return false; 303 } 304 305 static void processValue(Value value, LiveMap &liveMap) { 306 bool provedLive = llvm::any_of(value.getUses(), [&](OpOperand &use) { 307 if (isUseSpeciallyKnownDead(use, liveMap)) 308 return false; 309 return liveMap.wasProvenLive(use.getOwner()); 310 }); 311 if (provedLive) 312 liveMap.setProvedLive(value); 313 } 314 315 static void propagateLiveness(Region ®ion, LiveMap &liveMap); 316 317 static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) { 318 // Terminators are always live. 319 liveMap.setProvedLive(op); 320 321 // Check to see if we can reason about the successor operands and mutate them. 322 BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(op); 323 if (!branchInterface) { 324 for (Block *successor : op->getSuccessors()) 325 for (BlockArgument arg : successor->getArguments()) 326 liveMap.setProvedLive(arg); 327 return; 328 } 329 330 // If we can't reason about the operand to a successor, conservatively mark 331 // it as live. 332 for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) { 333 SuccessorOperands successorOperands = 334 branchInterface.getSuccessorOperands(i); 335 for (unsigned opI = 0, opE = successorOperands.getProducedOperandCount(); 336 opI != opE; ++opI) 337 liveMap.setProvedLive(op->getSuccessor(i)->getArgument(opI)); 338 } 339 } 340 341 static void propagateLiveness(Operation *op, LiveMap &liveMap) { 342 // Recurse on any regions the op has. 343 for (Region ®ion : op->getRegions()) 344 propagateLiveness(region, liveMap); 345 346 // Process terminator operations. 347 if (op->hasTrait<OpTrait::IsTerminator>()) 348 return propagateTerminatorLiveness(op, liveMap); 349 350 // Don't reprocess live operations. 351 if (liveMap.wasProvenLive(op)) 352 return; 353 354 // Process the op itself. 355 if (!wouldOpBeTriviallyDead(op)) 356 return liveMap.setProvedLive(op); 357 358 // If the op isn't intrinsically alive, check it's results. 359 for (Value value : op->getResults()) 360 processValue(value, liveMap); 361 } 362 363 static void propagateLiveness(Region ®ion, LiveMap &liveMap) { 364 if (region.empty()) 365 return; 366 367 for (Block *block : llvm::post_order(®ion.front())) { 368 // We process block arguments after the ops in the block, to promote 369 // faster convergence to a fixed point (we try to visit uses before defs). 370 for (Operation &op : llvm::reverse(block->getOperations())) 371 propagateLiveness(&op, liveMap); 372 373 // We currently do not remove entry block arguments, so there is no need to 374 // track their liveness. 375 // TODO: We could track these and enable removing dead operands/arguments 376 // from region control flow operations. 377 if (block->isEntryBlock()) 378 continue; 379 380 for (Value value : block->getArguments()) { 381 if (!liveMap.wasProvenLive(value)) 382 processValue(value, liveMap); 383 } 384 } 385 } 386 387 static void eraseTerminatorSuccessorOperands(Operation *terminator, 388 LiveMap &liveMap) { 389 BranchOpInterface branchOp = dyn_cast<BranchOpInterface>(terminator); 390 if (!branchOp) 391 return; 392 393 for (unsigned succI = 0, succE = terminator->getNumSuccessors(); 394 succI < succE; succI++) { 395 // Iterating successors in reverse is not strictly needed, since we 396 // aren't erasing any successors. But it is slightly more efficient 397 // since it will promote later operands of the terminator being erased 398 // first, reducing the quadratic-ness. 399 unsigned succ = succE - succI - 1; 400 SuccessorOperands succOperands = branchOp.getSuccessorOperands(succ); 401 Block *successor = terminator->getSuccessor(succ); 402 403 for (unsigned argI = 0, argE = succOperands.size(); argI < argE; ++argI) { 404 // Iterating args in reverse is needed for correctness, to avoid 405 // shifting later args when earlier args are erased. 406 unsigned arg = argE - argI - 1; 407 if (!liveMap.wasProvenLive(successor->getArgument(arg))) 408 succOperands.erase(arg); 409 } 410 } 411 } 412 413 static LogicalResult deleteDeadness(RewriterBase &rewriter, 414 MutableArrayRef<Region> regions, 415 LiveMap &liveMap) { 416 bool erasedAnything = false; 417 for (Region ®ion : regions) { 418 if (region.empty()) 419 continue; 420 bool hasSingleBlock = llvm::hasSingleElement(region); 421 422 // Delete every operation that is not live. Graph regions may have cycles 423 // in the use-def graph, so we must explicitly dropAllUses() from each 424 // operation as we erase it. Visiting the operations in post-order 425 // guarantees that in SSA CFG regions value uses are removed before defs, 426 // which makes dropAllUses() a no-op. 427 for (Block *block : llvm::post_order(®ion.front())) { 428 if (!hasSingleBlock) 429 eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap); 430 for (Operation &childOp : 431 llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) { 432 if (!liveMap.wasProvenLive(&childOp)) { 433 erasedAnything = true; 434 childOp.dropAllUses(); 435 rewriter.eraseOp(&childOp); 436 } else { 437 erasedAnything |= succeeded( 438 deleteDeadness(rewriter, childOp.getRegions(), liveMap)); 439 } 440 } 441 } 442 // Delete block arguments. 443 // The entry block has an unknown contract with their enclosing block, so 444 // skip it. 445 for (Block &block : llvm::drop_begin(region.getBlocks(), 1)) { 446 block.eraseArguments( 447 [&](BlockArgument arg) { return !liveMap.wasProvenLive(arg); }); 448 } 449 } 450 return success(erasedAnything); 451 } 452 453 // This function performs a simple dead code elimination algorithm over the 454 // given regions. 455 // 456 // The overall goal is to prove that Values are dead, which allows deleting ops 457 // and block arguments. 458 // 459 // This uses an optimistic algorithm that assumes everything is dead until 460 // proved otherwise, allowing it to delete recursively dead cycles. 461 // 462 // This is a simple fixed-point dataflow analysis algorithm on a lattice 463 // {Dead,Alive}. Because liveness flows backward, we generally try to 464 // iterate everything backward to speed up convergence to the fixed-point. This 465 // allows for being able to delete recursively dead cycles of the use-def graph, 466 // including block arguments. 467 // 468 // This function returns success if any operations or arguments were deleted, 469 // failure otherwise. 470 LogicalResult mlir::runRegionDCE(RewriterBase &rewriter, 471 MutableArrayRef<Region> regions) { 472 LiveMap liveMap; 473 do { 474 liveMap.resetChanged(); 475 476 for (Region ®ion : regions) 477 propagateLiveness(region, liveMap); 478 } while (liveMap.hasChanged()); 479 480 return deleteDeadness(rewriter, regions, liveMap); 481 } 482 483 //===----------------------------------------------------------------------===// 484 // Block Merging 485 //===----------------------------------------------------------------------===// 486 487 //===----------------------------------------------------------------------===// 488 // BlockEquivalenceData 489 490 namespace { 491 /// This class contains the information for comparing the equivalencies of two 492 /// blocks. Blocks are considered equivalent if they contain the same operations 493 /// in the same order. The only allowed divergence is for operands that come 494 /// from sources outside of the parent block, i.e. the uses of values produced 495 /// within the block must be equivalent. 496 /// e.g., 497 /// Equivalent: 498 /// ^bb1(%arg0: i32) 499 /// return %arg0, %foo : i32, i32 500 /// ^bb2(%arg1: i32) 501 /// return %arg1, %bar : i32, i32 502 /// Not Equivalent: 503 /// ^bb1(%arg0: i32) 504 /// return %foo, %arg0 : i32, i32 505 /// ^bb2(%arg1: i32) 506 /// return %arg1, %bar : i32, i32 507 struct BlockEquivalenceData { 508 BlockEquivalenceData(Block *block); 509 510 /// Return the order index for the given value that is within the block of 511 /// this data. 512 unsigned getOrderOf(Value value) const; 513 514 /// The block this data refers to. 515 Block *block; 516 /// A hash value for this block. 517 llvm::hash_code hash; 518 /// A map of result producing operations to their relative orders within this 519 /// block. The order of an operation is the number of defined values that are 520 /// produced within the block before this operation. 521 DenseMap<Operation *, unsigned> opOrderIndex; 522 }; 523 } // namespace 524 525 BlockEquivalenceData::BlockEquivalenceData(Block *block) 526 : block(block), hash(0) { 527 unsigned orderIt = block->getNumArguments(); 528 for (Operation &op : *block) { 529 if (unsigned numResults = op.getNumResults()) { 530 opOrderIndex.try_emplace(&op, orderIt); 531 orderIt += numResults; 532 } 533 auto opHash = OperationEquivalence::computeHash( 534 &op, OperationEquivalence::ignoreHashValue, 535 OperationEquivalence::ignoreHashValue, 536 OperationEquivalence::IgnoreLocations); 537 hash = llvm::hash_combine(hash, opHash); 538 } 539 } 540 541 unsigned BlockEquivalenceData::getOrderOf(Value value) const { 542 assert(value.getParentBlock() == block && "expected value of this block"); 543 544 // Arguments use the argument number as the order index. 545 if (BlockArgument arg = dyn_cast<BlockArgument>(value)) 546 return arg.getArgNumber(); 547 548 // Otherwise, the result order is offset from the parent op's order. 549 OpResult result = cast<OpResult>(value); 550 auto opOrderIt = opOrderIndex.find(result.getDefiningOp()); 551 assert(opOrderIt != opOrderIndex.end() && "expected op to have an order"); 552 return opOrderIt->second + result.getResultNumber(); 553 } 554 555 //===----------------------------------------------------------------------===// 556 // BlockMergeCluster 557 558 namespace { 559 /// This class represents a cluster of blocks to be merged together. 560 class BlockMergeCluster { 561 public: 562 BlockMergeCluster(BlockEquivalenceData &&leaderData) 563 : leaderData(std::move(leaderData)) {} 564 565 /// Attempt to add the given block to this cluster. Returns success if the 566 /// block was merged, failure otherwise. 567 LogicalResult addToCluster(BlockEquivalenceData &blockData); 568 569 /// Try to merge all of the blocks within this cluster into the leader block. 570 LogicalResult merge(RewriterBase &rewriter); 571 572 private: 573 /// The equivalence data for the leader of the cluster. 574 BlockEquivalenceData leaderData; 575 576 /// The set of blocks that can be merged into the leader. 577 llvm::SmallSetVector<Block *, 1> blocksToMerge; 578 579 /// A set of operand+index pairs that correspond to operands that need to be 580 /// replaced by arguments when the cluster gets merged. 581 std::set<std::pair<int, int>> operandsToMerge; 582 }; 583 } // namespace 584 585 LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) { 586 if (leaderData.hash != blockData.hash) 587 return failure(); 588 Block *leaderBlock = leaderData.block, *mergeBlock = blockData.block; 589 if (leaderBlock->getArgumentTypes() != mergeBlock->getArgumentTypes()) 590 return failure(); 591 592 // A set of operands that mismatch between the leader and the new block. 593 SmallVector<std::pair<int, int>, 8> mismatchedOperands; 594 auto lhsIt = leaderBlock->begin(), lhsE = leaderBlock->end(); 595 auto rhsIt = blockData.block->begin(), rhsE = blockData.block->end(); 596 for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) { 597 // Check that the operations are equivalent. 598 if (!OperationEquivalence::isEquivalentTo( 599 &*lhsIt, &*rhsIt, OperationEquivalence::ignoreValueEquivalence, 600 /*markEquivalent=*/nullptr, 601 OperationEquivalence::Flags::IgnoreLocations)) 602 return failure(); 603 604 // Compare the operands of the two operations. If the operand is within 605 // the block, it must refer to the same operation. 606 auto lhsOperands = lhsIt->getOperands(), rhsOperands = rhsIt->getOperands(); 607 for (int operand : llvm::seq<int>(0, lhsIt->getNumOperands())) { 608 Value lhsOperand = lhsOperands[operand]; 609 Value rhsOperand = rhsOperands[operand]; 610 if (lhsOperand == rhsOperand) 611 continue; 612 // Check that the types of the operands match. 613 if (lhsOperand.getType() != rhsOperand.getType()) 614 return failure(); 615 616 // Check that these uses are both external, or both internal. 617 bool lhsIsInBlock = lhsOperand.getParentBlock() == leaderBlock; 618 bool rhsIsInBlock = rhsOperand.getParentBlock() == mergeBlock; 619 if (lhsIsInBlock != rhsIsInBlock) 620 return failure(); 621 // Let the operands differ if they are defined in a different block. These 622 // will become new arguments if the blocks get merged. 623 if (!lhsIsInBlock) { 624 625 // Check whether the operands aren't the result of an immediate 626 // predecessors terminator. In that case we are not able to use it as a 627 // successor operand when branching to the merged block as it does not 628 // dominate its producing operation. 629 auto isValidSuccessorArg = [](Block *block, Value operand) { 630 if (operand.getDefiningOp() != 631 operand.getParentBlock()->getTerminator()) 632 return true; 633 return !llvm::is_contained(block->getPredecessors(), 634 operand.getParentBlock()); 635 }; 636 637 if (!isValidSuccessorArg(leaderBlock, lhsOperand) || 638 !isValidSuccessorArg(mergeBlock, rhsOperand)) 639 return failure(); 640 641 mismatchedOperands.emplace_back(opI, operand); 642 continue; 643 } 644 645 // Otherwise, these operands must have the same logical order within the 646 // parent block. 647 if (leaderData.getOrderOf(lhsOperand) != blockData.getOrderOf(rhsOperand)) 648 return failure(); 649 } 650 651 // If the lhs or rhs has external uses, the blocks cannot be merged as the 652 // merged version of this operation will not be either the lhs or rhs 653 // alone (thus semantically incorrect), but some mix dependending on which 654 // block preceeded this. 655 // TODO allow merging of operations when one block does not dominate the 656 // other 657 if (rhsIt->isUsedOutsideOfBlock(mergeBlock) || 658 lhsIt->isUsedOutsideOfBlock(leaderBlock)) { 659 return failure(); 660 } 661 } 662 // Make sure that the block sizes are equivalent. 663 if (lhsIt != lhsE || rhsIt != rhsE) 664 return failure(); 665 666 // If we get here, the blocks are equivalent and can be merged. 667 operandsToMerge.insert(mismatchedOperands.begin(), mismatchedOperands.end()); 668 blocksToMerge.insert(blockData.block); 669 return success(); 670 } 671 672 /// Returns true if the predecessor terminators of the given block can not have 673 /// their operands updated. 674 static bool ableToUpdatePredOperands(Block *block) { 675 for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) { 676 if (!isa<BranchOpInterface>((*it)->getTerminator())) 677 return false; 678 } 679 return true; 680 } 681 682 /// Prunes the redundant list of new arguments. E.g., if we are passing an 683 /// argument list like [x, y, z, x] this would return [x, y, z] and it would 684 /// update the `block` (to whom the argument are passed to) accordingly. The new 685 /// arguments are passed as arguments at the back of the block, hence we need to 686 /// know how many `numOldArguments` were before, in order to correctly replace 687 /// the new arguments in the block 688 static SmallVector<SmallVector<Value, 8>, 2> pruneRedundantArguments( 689 const SmallVector<SmallVector<Value, 8>, 2> &newArguments, 690 RewriterBase &rewriter, unsigned numOldArguments, Block *block) { 691 692 SmallVector<SmallVector<Value, 8>, 2> newArgumentsPruned( 693 newArguments.size(), SmallVector<Value, 8>()); 694 695 if (newArguments.empty()) 696 return newArguments; 697 698 // `newArguments` is a 2D array of size `numLists` x `numArgs` 699 unsigned numLists = newArguments.size(); 700 unsigned numArgs = newArguments[0].size(); 701 702 // Map that for each arg index contains the index that we can use in place of 703 // the original index. E.g., if we have newArgs = [x, y, z, x], we will have 704 // idxToReplacement[3] = 0 705 llvm::DenseMap<unsigned, unsigned> idxToReplacement; 706 707 // This is a useful data structure to track the first appearance of a Value 708 // on a given list of arguments 709 DenseMap<Value, unsigned> firstValueToIdx; 710 for (unsigned j = 0; j < numArgs; ++j) { 711 Value newArg = newArguments[0][j]; 712 firstValueToIdx.try_emplace(newArg, j); 713 } 714 715 // Go through the first list of arguments (list 0). 716 for (unsigned j = 0; j < numArgs; ++j) { 717 // Look back to see if there are possible redundancies in list 0. Please 718 // note that we are using a map to annotate when an argument was seen first 719 // to avoid a O(N^2) algorithm. This has the drawback that if we have two 720 // lists like: 721 // list0: [%a, %a, %a] 722 // list1: [%c, %b, %b] 723 // We cannot simplify it, because firstValueToIdx[%a] = 0, but we cannot 724 // point list1[1](==%b) or list1[2](==%b) to list1[0](==%c). However, since 725 // the number of arguments can be potentially unbounded we cannot afford a 726 // O(N^2) algorithm (to search to all the possible pairs) and we need to 727 // accept the trade-off. 728 unsigned k = firstValueToIdx[newArguments[0][j]]; 729 if (k == j) 730 continue; 731 732 bool shouldReplaceJ = true; 733 unsigned replacement = k; 734 // If a possible redundancy is found, then scan the other lists: we 735 // can prune the arguments if and only if they are redundant in every 736 // list. 737 for (unsigned i = 1; i < numLists; ++i) 738 shouldReplaceJ = 739 shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]); 740 // Save the replacement. 741 if (shouldReplaceJ) 742 idxToReplacement[j] = replacement; 743 } 744 745 // Populate the pruned argument list. 746 for (unsigned i = 0; i < numLists; ++i) 747 for (unsigned j = 0; j < numArgs; ++j) 748 if (!idxToReplacement.contains(j)) 749 newArgumentsPruned[i].push_back(newArguments[i][j]); 750 751 // Replace the block's redundant arguments. 752 SmallVector<unsigned> toErase; 753 for (auto [idx, arg] : llvm::enumerate(block->getArguments())) { 754 if (idxToReplacement.contains(idx)) { 755 Value oldArg = block->getArgument(numOldArguments + idx); 756 Value newArg = 757 block->getArgument(numOldArguments + idxToReplacement[idx]); 758 rewriter.replaceAllUsesWith(oldArg, newArg); 759 toErase.push_back(numOldArguments + idx); 760 } 761 } 762 763 // Erase the block's redundant arguments. 764 for (unsigned idxToErase : llvm::reverse(toErase)) 765 block->eraseArgument(idxToErase); 766 return newArgumentsPruned; 767 } 768 769 LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) { 770 // Don't consider clusters that don't have blocks to merge. 771 if (blocksToMerge.empty()) 772 return failure(); 773 774 Block *leaderBlock = leaderData.block; 775 if (!operandsToMerge.empty()) { 776 // If the cluster has operands to merge, verify that the predecessor 777 // terminators of each of the blocks can have their successor operands 778 // updated. 779 // TODO: We could try and sub-partition this cluster if only some blocks 780 // cause the mismatch. 781 if (!ableToUpdatePredOperands(leaderBlock) || 782 !llvm::all_of(blocksToMerge, ableToUpdatePredOperands)) 783 return failure(); 784 785 // Collect the iterators for each of the blocks to merge. We will walk all 786 // of the iterators at once to avoid operand index invalidation. 787 SmallVector<Block::iterator, 2> blockIterators; 788 blockIterators.reserve(blocksToMerge.size() + 1); 789 blockIterators.push_back(leaderBlock->begin()); 790 for (Block *mergeBlock : blocksToMerge) 791 blockIterators.push_back(mergeBlock->begin()); 792 793 // Update each of the predecessor terminators with the new arguments. 794 SmallVector<SmallVector<Value, 8>, 2> newArguments( 795 1 + blocksToMerge.size(), 796 SmallVector<Value, 8>(operandsToMerge.size())); 797 unsigned curOpIndex = 0; 798 unsigned numOldArguments = leaderBlock->getNumArguments(); 799 for (const auto &it : llvm::enumerate(operandsToMerge)) { 800 unsigned nextOpOffset = it.value().first - curOpIndex; 801 curOpIndex = it.value().first; 802 803 // Process the operand for each of the block iterators. 804 for (unsigned i = 0, e = blockIterators.size(); i != e; ++i) { 805 Block::iterator &blockIter = blockIterators[i]; 806 std::advance(blockIter, nextOpOffset); 807 auto &operand = blockIter->getOpOperand(it.value().second); 808 newArguments[i][it.index()] = operand.get(); 809 810 // Update the operand and insert an argument if this is the leader. 811 if (i == 0) { 812 Value operandVal = operand.get(); 813 operand.set(leaderBlock->addArgument(operandVal.getType(), 814 operandVal.getLoc())); 815 } 816 } 817 } 818 819 // Prune redundant arguments and update the leader block argument list 820 newArguments = pruneRedundantArguments(newArguments, rewriter, 821 numOldArguments, leaderBlock); 822 823 // Update the predecessors for each of the blocks. 824 auto updatePredecessors = [&](Block *block, unsigned clusterIndex) { 825 for (auto predIt = block->pred_begin(), predE = block->pred_end(); 826 predIt != predE; ++predIt) { 827 auto branch = cast<BranchOpInterface>((*predIt)->getTerminator()); 828 unsigned succIndex = predIt.getSuccessorIndex(); 829 branch.getSuccessorOperands(succIndex).append( 830 newArguments[clusterIndex]); 831 } 832 }; 833 updatePredecessors(leaderBlock, /*clusterIndex=*/0); 834 for (unsigned i = 0, e = blocksToMerge.size(); i != e; ++i) 835 updatePredecessors(blocksToMerge[i], /*clusterIndex=*/i + 1); 836 } 837 838 // Replace all uses of the merged blocks with the leader and erase them. 839 for (Block *block : blocksToMerge) { 840 block->replaceAllUsesWith(leaderBlock); 841 rewriter.eraseBlock(block); 842 } 843 return success(); 844 } 845 846 /// Identify identical blocks within the given region and merge them, inserting 847 /// new block arguments as necessary. Returns success if any blocks were merged, 848 /// failure otherwise. 849 static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter, 850 Region ®ion) { 851 if (region.empty() || llvm::hasSingleElement(region)) 852 return failure(); 853 854 // Identify sets of blocks, other than the entry block, that branch to the 855 // same successors. We will use these groups to create clusters of equivalent 856 // blocks. 857 DenseMap<SuccessorRange, SmallVector<Block *, 1>> matchingSuccessors; 858 for (Block &block : llvm::drop_begin(region, 1)) 859 matchingSuccessors[block.getSuccessors()].push_back(&block); 860 861 bool mergedAnyBlocks = false; 862 for (ArrayRef<Block *> blocks : llvm::make_second_range(matchingSuccessors)) { 863 if (blocks.size() == 1) 864 continue; 865 866 SmallVector<BlockMergeCluster, 1> clusters; 867 for (Block *block : blocks) { 868 BlockEquivalenceData data(block); 869 870 // Don't allow merging if this block has any regions. 871 // TODO: Add support for regions if necessary. 872 bool hasNonEmptyRegion = llvm::any_of(*block, [](Operation &op) { 873 return llvm::any_of(op.getRegions(), 874 [](Region ®ion) { return !region.empty(); }); 875 }); 876 if (hasNonEmptyRegion) 877 continue; 878 879 // Don't allow merging if this block's arguments are used outside of the 880 // original block. 881 bool argHasExternalUsers = llvm::any_of( 882 block->getArguments(), [block](mlir::BlockArgument &arg) { 883 return arg.isUsedOutsideOfBlock(block); 884 }); 885 if (argHasExternalUsers) 886 continue; 887 888 // Try to add this block to an existing cluster. 889 bool addedToCluster = false; 890 for (auto &cluster : clusters) 891 if ((addedToCluster = succeeded(cluster.addToCluster(data)))) 892 break; 893 if (!addedToCluster) 894 clusters.emplace_back(std::move(data)); 895 } 896 for (auto &cluster : clusters) 897 mergedAnyBlocks |= succeeded(cluster.merge(rewriter)); 898 } 899 900 return success(mergedAnyBlocks); 901 } 902 903 /// Identify identical blocks within the given regions and merge them, inserting 904 /// new block arguments as necessary. 905 static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter, 906 MutableArrayRef<Region> regions) { 907 llvm::SmallSetVector<Region *, 1> worklist; 908 for (auto ®ion : regions) 909 worklist.insert(®ion); 910 bool anyChanged = false; 911 while (!worklist.empty()) { 912 Region *region = worklist.pop_back_val(); 913 if (succeeded(mergeIdenticalBlocks(rewriter, *region))) { 914 worklist.insert(region); 915 anyChanged = true; 916 } 917 918 // Add any nested regions to the worklist. 919 for (Block &block : *region) 920 for (auto &op : block) 921 for (auto &nestedRegion : op.getRegions()) 922 worklist.insert(&nestedRegion); 923 } 924 925 return success(anyChanged); 926 } 927 928 /// If a block's argument is always the same across different invocations, then 929 /// drop the argument and use the value directly inside the block 930 static LogicalResult dropRedundantArguments(RewriterBase &rewriter, 931 Block &block) { 932 SmallVector<size_t> argsToErase; 933 934 // Go through the arguments of the block. 935 for (auto [argIdx, blockOperand] : llvm::enumerate(block.getArguments())) { 936 bool sameArg = true; 937 Value commonValue; 938 939 // Go through the block predecessor and flag if they pass to the block 940 // different values for the same argument. 941 for (Block::pred_iterator predIt = block.pred_begin(), 942 predE = block.pred_end(); 943 predIt != predE; ++predIt) { 944 auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator()); 945 if (!branch) { 946 sameArg = false; 947 break; 948 } 949 unsigned succIndex = predIt.getSuccessorIndex(); 950 SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex); 951 auto branchOperands = succOperands.getForwardedOperands(); 952 if (!commonValue) { 953 commonValue = branchOperands[argIdx]; 954 continue; 955 } 956 if (branchOperands[argIdx] != commonValue) { 957 sameArg = false; 958 break; 959 } 960 } 961 962 // If they are passing the same value, drop the argument. 963 if (commonValue && sameArg) { 964 argsToErase.push_back(argIdx); 965 966 // Remove the argument from the block. 967 rewriter.replaceAllUsesWith(blockOperand, commonValue); 968 } 969 } 970 971 // Remove the arguments. 972 for (size_t argIdx : llvm::reverse(argsToErase)) { 973 block.eraseArgument(argIdx); 974 975 // Remove the argument from the branch ops. 976 for (auto predIt = block.pred_begin(), predE = block.pred_end(); 977 predIt != predE; ++predIt) { 978 auto branch = cast<BranchOpInterface>((*predIt)->getTerminator()); 979 unsigned succIndex = predIt.getSuccessorIndex(); 980 SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex); 981 succOperands.erase(argIdx); 982 } 983 } 984 return success(!argsToErase.empty()); 985 } 986 987 /// This optimization drops redundant argument to blocks. I.e., if a given 988 /// argument to a block receives the same value from each of the block 989 /// predecessors, we can remove the argument from the block and use directly the 990 /// original value. This is a simple example: 991 /// 992 /// %cond = llvm.call @rand() : () -> i1 993 /// %val0 = llvm.mlir.constant(1 : i64) : i64 994 /// %val1 = llvm.mlir.constant(2 : i64) : i64 995 /// %val2 = llvm.mlir.constant(3 : i64) : i64 996 /// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2 997 /// : i64) 998 /// 999 /// ^bb1(%arg0 : i64, %arg1 : i64): 1000 /// llvm.call @foo(%arg0, %arg1) 1001 /// 1002 /// The previous IR can be rewritten as: 1003 /// %cond = llvm.call @rand() : () -> i1 1004 /// %val0 = llvm.mlir.constant(1 : i64) : i64 1005 /// %val1 = llvm.mlir.constant(2 : i64) : i64 1006 /// %val2 = llvm.mlir.constant(3 : i64) : i64 1007 /// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64) 1008 /// 1009 /// ^bb1(%arg0 : i64): 1010 /// llvm.call @foo(%val0, %arg0) 1011 /// 1012 static LogicalResult dropRedundantArguments(RewriterBase &rewriter, 1013 MutableArrayRef<Region> regions) { 1014 llvm::SmallSetVector<Region *, 1> worklist; 1015 for (Region ®ion : regions) 1016 worklist.insert(®ion); 1017 bool anyChanged = false; 1018 while (!worklist.empty()) { 1019 Region *region = worklist.pop_back_val(); 1020 1021 // Add any nested regions to the worklist. 1022 for (Block &block : *region) { 1023 anyChanged = 1024 succeeded(dropRedundantArguments(rewriter, block)) || anyChanged; 1025 1026 for (Operation &op : block) 1027 for (Region &nestedRegion : op.getRegions()) 1028 worklist.insert(&nestedRegion); 1029 } 1030 } 1031 return success(anyChanged); 1032 } 1033 1034 //===----------------------------------------------------------------------===// 1035 // Region Simplification 1036 //===----------------------------------------------------------------------===// 1037 1038 /// Run a set of structural simplifications over the given regions. This 1039 /// includes transformations like unreachable block elimination, dead argument 1040 /// elimination, as well as some other DCE. This function returns success if any 1041 /// of the regions were simplified, failure otherwise. 1042 LogicalResult mlir::simplifyRegions(RewriterBase &rewriter, 1043 MutableArrayRef<Region> regions, 1044 bool mergeBlocks) { 1045 bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions)); 1046 bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions)); 1047 bool mergedIdenticalBlocks = false; 1048 bool droppedRedundantArguments = false; 1049 if (mergeBlocks) { 1050 mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions)); 1051 droppedRedundantArguments = 1052 succeeded(dropRedundantArguments(rewriter, regions)); 1053 } 1054 return success(eliminatedBlocks || eliminatedOpsOrArgs || 1055 mergedIdenticalBlocks || droppedRedundantArguments); 1056 } 1057