//===- Mem2Reg.cpp - Promotes memory slots into values ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Mem2Reg.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/RegionKindInterface.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/GenericIteratedDominanceFrontier.h"

namespace mlir {
#define GEN_PASS_DEF_MEM2REG
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "mem2reg"

using namespace mlir;

/// mem2reg
///
/// This pass turns unnecessary uses of automatically allocated memory slots
/// into direct Value-based operations. For example, it will simplify a store
/// of a constant to a memory slot followed by a load of that slot into a
/// direct use of the constant. In other words, given a memory slot addressed
/// by a non-aliased "pointer" Value, mem2reg removes all the uses of that
/// pointer.
///
/// Within a block, this is done by following the chain of stores and loads of
/// the slot and replacing the results of loads with the values previously
/// stored. If a load happens before any other store, a poison value is used
/// instead.
///
/// Control flow can create situations where a load could be replaced by
/// multiple possible stores depending on the control flow path taken. As a
/// result, this pass must introduce new block arguments in some blocks to
/// accommodate the multiple possible definitions. Each predecessor will
/// populate the block argument with the definition reached at its end. With
/// this, the value stored can be well defined at block boundaries, allowing
/// replacements to be propagated across blocks.
///
/// This pass performs this transformation in four main steps. The first two
/// steps are performed during an analysis phase that does not mutate IR.
///
/// The two steps of the analysis phase are the following:
/// - A first step computes the list of operations that transitively use the
/// memory slot we would like to promote. The purpose of this phase is to
/// identify which uses must be removed to promote the slot, either by rewiring
/// the user or deleting it. Naturally, direct uses of the slot must be removed.
/// Sometimes additional uses must also be removed: this is notably the case
/// when a direct user of the slot cannot rewire its use and must delete itself,
/// and thus must make its users no longer use it. If any of those uses cannot
/// be removed by their users in any way, promotion cannot continue: this is
/// decided at this step.
/// - A second step computes the list of blocks where a block argument will be
/// needed ("merge points") without mutating the IR. These blocks are the blocks
/// leading to a definition clash between two predecessors. Such blocks happen
/// to be the Iterated Dominance Frontier (IDF) of the set of blocks containing
/// a store, as they represent the points where a clear defining dominator stops
/// existing. Computing this information in advance makes it possible to check
/// that the terminators that will forward values are capable of doing so (if
/// they are not, promotion aborts at this step).
///
/// At this point, promotion is guaranteed to happen, and the mutation phase can
/// begin with the following steps:
/// - A third step computes the reaching definition of the memory slot at each
/// blocking user. This is the core of the mem2reg algorithm, also known as
/// load-store forwarding. This step analyses loads and stores and propagates
/// which value must be stored in the slot at each blocking user. This is
/// achieved by doing a depth-first walk of the dominator tree of the function.
/// This is sufficient because the reaching definition at the beginning of a
/// block is either its new block argument if it is a merge block, or the
/// definition reaching the end of its immediate dominator (parent in the
/// dominator tree). We can therefore propagate this information down the
/// dominator tree to proceed with renaming within blocks.
/// - The fourth and final step uses the reaching definition to remove blocking
/// uses in topological order.
///
/// For further reading, chapter three of SSA-based Compiler Design [1]
/// showcases SSA construction, where mem2reg is an adaptation of the same
/// process.
///
/// [1]: Rastello F. & Bouchez Tichadou F., SSA-based Compiler Design (2022),
/// Springer.
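///
/// As an illustration of the single-block case (a sketch in hypothetical
/// pseudo-IR, not tied to any specific dialect):
///
///   %slot = alloca : !ptr
///   store %c42, %slot
///   %v = load %slot     // load-store forwarding: %v becomes %c42
///   %w = load %slot     // also forwarded to %c42
///
/// After promotion, both loads are replaced by direct uses of %c42 and the
/// slot allocation disappears entirely.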

namespace {

using BlockingUsesMap =
    llvm::MapVector<Operation *, SmallPtrSet<OpOperand *, 4>>;

/// Information computed during promotion analysis used to perform actual
/// promotion.
struct MemorySlotPromotionInfo {
  /// Blocks for which at least two definitions of the slot's value clash.
  SmallPtrSet<Block *, 8> mergePoints;
  /// Contains, for each operation, which uses must be eliminated by promotion.
  /// This is a DAG structure because if an operation must eliminate some of
  /// its uses, it is because the defining ops of the blocking uses requested
  /// it. The defining ops therefore must also have blocking uses or be the
  /// starting point of the blocking uses.
  BlockingUsesMap userToBlockingUses;
};
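
// As a sketch of how `userToBlockingUses` gets filled (hypothetical ops): if
// the slot pointer %slot is only used by a pointer cast %p = cast %slot, and
// the cast can only resolve its blocking use by deleting itself, then every
// use of %p becomes blocking in turn, and the map would contain:
//
//   cast -> { use of %slot }
//   load -> { use of %p }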

/// Computes information for basic slot promotion. This will check that direct
/// slot promotion can be performed, and provide the information to execute the
/// promotion. This does not mutate IR.
class MemorySlotPromotionAnalyzer {
public:
  MemorySlotPromotionAnalyzer(MemorySlot slot, DominanceInfo &dominance,
                              const DataLayout &dataLayout)
      : slot(slot), dominance(dominance), dataLayout(dataLayout) {}

  /// Computes the information for slot promotion if promotion is possible,
  /// and returns nothing otherwise.
  std::optional<MemorySlotPromotionInfo> computeInfo();

private:
  /// Computes the transitive uses of the slot that block promotion. This finds
  /// uses that would block the promotion, checks that the operation has a
  /// solution to remove the blocking use, and potentially forwards the analysis
  /// if the operation needs further blocking uses resolved to resolve its own
  /// uses (typically, removing its users because it will delete itself to
  /// resolve its own blocking uses). This fails if one of the transitive users
  /// cannot remove a requested use, in which case promotion must be prevented.
  LogicalResult computeBlockingUses(BlockingUsesMap &userToBlockingUses);

  /// Computes in which blocks the value stored in the slot is actually used,
  /// meaning blocks leading to a load. This method uses `definingBlocks`, the
  /// set of blocks containing a store to the slot (defining the value of the
  /// slot).
  SmallPtrSet<Block *, 16>
  computeSlotLiveIn(SmallPtrSetImpl<Block *> &definingBlocks);

  /// Computes the points at which multiple re-definitions of the slot's value
  /// (stores) may conflict.
  void computeMergePoints(SmallPtrSetImpl<Block *> &mergePoints);

  /// Checks that the predecessors of merge points can properly provide their
  /// current definition of the value stored in the slot to the merge point.
  /// This can notably be an issue if the terminator used does not have the
  /// ability to forward values through block operands.
  bool areMergePointsUsable(SmallPtrSetImpl<Block *> &mergePoints);

  MemorySlot slot;
  DominanceInfo &dominance;
  const DataLayout &dataLayout;
};

using BlockIndexCache = DenseMap<Region *, DenseMap<Block *, size_t>>;

/// The MemorySlotPromoter handles the state of promoting a memory slot. It
/// wraps a slot and its associated allocator. This will perform the mutation of
/// IR.
class MemorySlotPromoter {
public:
  MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator,
                     OpBuilder &builder, DominanceInfo &dominance,
                     const DataLayout &dataLayout, MemorySlotPromotionInfo info,
                     const Mem2RegStatistics &statistics,
                     BlockIndexCache &blockIndexCache);

  /// Actually promotes the slot by mutating IR. Promoting a slot DOES
  /// invalidate the MemorySlotPromotionInfo of other slots. Preparation of
  /// promotion info should NOT be performed in batches.
  /// Returns a promotable allocation op if a new allocator was created, nullopt
  /// otherwise.
  std::optional<PromotableAllocationOpInterface> promoteSlot();

private:
  /// Computes the reaching definition for all the operations that require
  /// promotion. `reachingDef` is the value the slot should contain at the
  /// beginning of the block. This method returns the reached definition at the
  /// end of the block. This method must be called at most once per block.
  Value computeReachingDefInBlock(Block *block, Value reachingDef);

  /// Computes the reaching definition for all the operations that require
  /// promotion. `reachingDef` corresponds to the initial value the
  /// slot will contain before any write, typically a poison value.
  /// This method must be called at most once per region.
  void computeReachingDefInRegion(Region *region, Value reachingDef);

  /// Removes the blocking uses of the slot, in topological order.
  void removeBlockingUses();

  /// Lazily-constructed default value representing the content of the slot
  /// when no store has been executed. This function may mutate IR.
  Value getOrCreateDefaultValue();

  MemorySlot slot;
  PromotableAllocationOpInterface allocator;
  OpBuilder &builder;
  /// Potentially non-initialized default value. Use `getOrCreateDefaultValue`
  /// to initialize it on demand.
  Value defaultValue;
  /// Contains the reaching definition at each operation. Reaching definitions
  /// are only computed for promotable memory operations with blocking uses.
  DenseMap<PromotableMemOpInterface, Value> reachingDefs;
  DenseMap<PromotableMemOpInterface, Value> replacedValuesMap;
  DominanceInfo &dominance;
  const DataLayout &dataLayout;
  MemorySlotPromotionInfo info;
  const Mem2RegStatistics &statistics;

  /// Shared cache of block indices of specific regions.
  BlockIndexCache &blockIndexCache;
};

} // namespace

MemorySlotPromoter::MemorySlotPromoter(
    MemorySlot slot, PromotableAllocationOpInterface allocator,
    OpBuilder &builder, DominanceInfo &dominance, const DataLayout &dataLayout,
    MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics,
    BlockIndexCache &blockIndexCache)
    : slot(slot), allocator(allocator), builder(builder), dominance(dominance),
      dataLayout(dataLayout), info(std::move(info)), statistics(statistics),
      blockIndexCache(blockIndexCache) {
#ifndef NDEBUG
  auto isResultOrNewBlockArgument = [&]() {
    if (BlockArgument arg = dyn_cast<BlockArgument>(slot.ptr))
      return arg.getOwner()->getParentOp() == allocator;
    return slot.ptr.getDefiningOp() == allocator;
  };

  assert(isResultOrNewBlockArgument() &&
         "a slot must be a result of the allocator or an argument of the child "
         "regions of the allocator");
#endif // NDEBUG
}

Value MemorySlotPromoter::getOrCreateDefaultValue() {
  if (defaultValue)
    return defaultValue;

  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointToStart(slot.ptr.getParentBlock());
  return defaultValue = allocator.getDefaultValue(slot, builder);
}
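
// Note (explanatory): as the file header describes, this default value models
// the slot's content when a load executes before any store; the allocator's
// getDefaultValue hook is expected to provide a poison-like placeholder for
// this purpose.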

LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
    BlockingUsesMap &userToBlockingUses) {
  // The promotion of an operation may require the promotion of further
  // operations (typically, removing operations that use an operation that must
  // delete itself). We thus need to start from the use of the slot pointer and
  // propagate further requests through the forward slice.

  // Because this pass currently only supports analysing the parent region of
  // the slot pointer, if that region is a graph region, any promotable memory
  // op using the slot would itself be in a graph region; the slot must
  // therefore be ignored.
  Region *slotPtrRegion = slot.ptr.getParentRegion();
  auto slotPtrRegionOp =
      dyn_cast<RegionKindInterface>(slotPtrRegion->getParentOp());
  if (slotPtrRegionOp &&
      slotPtrRegionOp.getRegionKind(slotPtrRegion->getRegionNumber()) ==
          RegionKind::Graph)
    return failure();

  // First, insert into the map that all immediate users of the slot pointer
  // must no longer use it.
  for (OpOperand &use : slot.ptr.getUses()) {
    SmallPtrSet<OpOperand *, 4> &blockingUses =
        userToBlockingUses[use.getOwner()];
    blockingUses.insert(&use);
  }

  // Then, propagate the requirements for the removal of uses. The
  // topologically-sorted forward slice guarantees that all blocking uses of an
  // operation have been computed before the operation itself is reached.
  // Operations are traversed in topological order of their uses, starting from
  // the slot pointer.
  SetVector<Operation *> forwardSlice;
  mlir::getForwardSlice(slot.ptr, &forwardSlice);
  for (Operation *user : forwardSlice) {
    // If the next operation has no blocking uses, everything is fine.
    auto it = userToBlockingUses.find(user);
    if (it == userToBlockingUses.end())
      continue;

    SmallPtrSet<OpOperand *, 4> &blockingUses = it->second;

    SmallVector<OpOperand *> newBlockingUses;
    // If the operation decides it cannot deal with removing the blocking uses,
    // promotion must fail.
    if (auto promotable = dyn_cast<PromotableOpInterface>(user)) {
      if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses,
                                       dataLayout))
        return failure();
    } else if (auto promotable = dyn_cast<PromotableMemOpInterface>(user)) {
      if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses,
                                       dataLayout))
        return failure();
    } else {
      // An operation that has blocking uses must be promoted. If it is not
      // promotable, promotion must fail.
      return failure();
    }

    // Then, register any new blocking uses for the operations that follow.
    for (OpOperand *blockingUse : newBlockingUses) {
      assert(llvm::is_contained(user->getResults(), blockingUse->get()));

      SmallPtrSetImpl<OpOperand *> &newUserBlockingUseSet =
          userToBlockingUses[blockingUse->getOwner()];
      newUserBlockingUseSet.insert(blockingUse);
    }
  }

  // Because this pass currently only supports analysing the parent region of
  // the slot pointer, if a promotable memory op that needs promotion is outside
  // of this region, promotion must fail because it will be impossible to
  // provide a valid `reachingDef` for it.
  for (auto &[toPromote, _] : userToBlockingUses)
    if (isa<PromotableMemOpInterface>(toPromote) &&
        toPromote->getParentRegion() != slot.ptr.getParentRegion())
      return failure();

  return success();
}
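
// For instance (illustrative, assuming a call operation that implements none
// of the promotable interfaces): if the slot pointer escapes into such a
// call, the use can neither be rewired nor deleted, so computeBlockingUses
// fails and promotion of this slot is aborted.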

SmallPtrSet<Block *, 16> MemorySlotPromotionAnalyzer::computeSlotLiveIn(
    SmallPtrSetImpl<Block *> &definingBlocks) {
  SmallPtrSet<Block *, 16> liveIn;

  // The worklist contains blocks in which it is known that the slot value is
  // live-in. The further blocks where this value is live-in will be inferred
  // from these.
  SmallVector<Block *> liveInWorkList;

  // Blocks with a load before any other store to the slot are the starting
  // points of the analysis. The slot value is definitely live-in in those
  // blocks.
  SmallPtrSet<Block *, 16> visited;
  for (Operation *user : slot.ptr.getUsers()) {
    if (!visited.insert(user->getBlock()).second)
      continue;

    for (Operation &op : user->getBlock()->getOperations()) {
      if (auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
        // If this operation loads the slot, it is loading from it before
        // ever writing to it, so the value is live-in in this block.
        if (memOp.loadsFrom(slot)) {
          liveInWorkList.push_back(user->getBlock());
          break;
        }

        // If we store to the slot, further loads will see that value.
        // Because we did not meet any load before, the value is not live-in.
        if (memOp.storesTo(slot))
          break;
      }
    }
  }

  // The information is then propagated to the predecessors until a def site
  // (store) is found.
  while (!liveInWorkList.empty()) {
    Block *liveInBlock = liveInWorkList.pop_back_val();

    if (!liveIn.insert(liveInBlock).second)
      continue;

    // If a predecessor is a defining block, either:
    // - It has a load before its first store, in which case it is live-in but
    //   has already been processed in the initialisation step.
    // - It has a store before any load, in which case it is not live-in.
    // At this stage, we can thus insert into the worklist only predecessors
    // that are not defining blocks.
    for (Block *pred : liveInBlock->getPredecessors())
      if (!definingBlocks.contains(pred))
        liveInWorkList.push_back(pred);
  }

  return liveIn;
}
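
// Illustration of the live-in computation (hypothetical CFG):
//
//   ^bb0: store %a to %slot; br ^bb1
//   ^bb1: %v = load %slot; ...
//
// ^bb1 loads before storing, so the slot value is live-in there; ^bb0 is a
// defining block, so propagation stops at it and it is not marked live-in.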

using IDFCalculator = llvm::IDFCalculatorBase<Block, false>;
void MemorySlotPromotionAnalyzer::computeMergePoints(
    SmallPtrSetImpl<Block *> &mergePoints) {
  if (slot.ptr.getParentRegion()->hasOneBlock())
    return;

  IDFCalculator idfCalculator(dominance.getDomTree(slot.ptr.getParentRegion()));

  SmallPtrSet<Block *, 16> definingBlocks;
  for (Operation *user : slot.ptr.getUsers())
    if (auto storeOp = dyn_cast<PromotableMemOpInterface>(user))
      if (storeOp.storesTo(slot))
        definingBlocks.insert(user->getBlock());

  idfCalculator.setDefiningBlocks(definingBlocks);

  SmallPtrSet<Block *, 16> liveIn = computeSlotLiveIn(definingBlocks);
  idfCalculator.setLiveInBlocks(liveIn);

  SmallVector<Block *> mergePointsVec;
  idfCalculator.calculate(mergePointsVec);

  mergePoints.insert(mergePointsVec.begin(), mergePointsVec.end());
}

bool MemorySlotPromotionAnalyzer::areMergePointsUsable(
    SmallPtrSetImpl<Block *> &mergePoints) {
  for (Block *mergePoint : mergePoints)
    for (Block *pred : mergePoint->getPredecessors())
      if (!isa<BranchOpInterface>(pred->getTerminator()))
        return false;

  return true;
}

std::optional<MemorySlotPromotionInfo>
MemorySlotPromotionAnalyzer::computeInfo() {
  MemorySlotPromotionInfo info;

  // First, find the set of operations that will need to be changed for the
  // promotion to happen. These operations need to resolve some of their uses,
  // either by rewiring them or simply deleting themselves. If any of them
  // cannot find a way to resolve their blocking uses, we abort the promotion.
  if (failed(computeBlockingUses(info.userToBlockingUses)))
    return {};

  // Then, compute blocks in which two or more definitions of the allocated
  // variable may conflict. These blocks will need a new block argument to
  // accommodate this.
  computeMergePoints(info.mergePoints);

  // The slot can be promoted if the block arguments to be created can
  // actually be populated with values, which may not be possible depending
  // on their predecessors.
  if (!areMergePointsUsable(info.mergePoints))
    return {};

  return info;
}
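
// Merge point illustration (hypothetical CFG): with stores in ^bb1 and ^bb2
// both branching to ^bb3, ^bb3 lies on the IDF of the defining blocks and
// becomes a merge point:
//
//   ^bb1: store %a to %slot; br ^bb3
//   ^bb2: store %b to %slot; br ^bb3
//   ^bb3: %v = load %slot
//
// Promotion rewrites this so that ^bb3 takes a block argument, with ^bb1
// forwarding %a and ^bb2 forwarding %b to it.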

Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
                                                    Value reachingDef) {
  SmallVector<Operation *> blockOps;
  for (Operation &op : block->getOperations())
    blockOps.push_back(&op);
  for (Operation *op : blockOps) {
    if (auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
      if (info.userToBlockingUses.contains(memOp))
        reachingDefs.insert({memOp, reachingDef});

      if (memOp.storesTo(slot)) {
        builder.setInsertionPointAfter(memOp);
        Value stored = memOp.getStored(slot, builder, reachingDef, dataLayout);
        assert(stored && "a memory operation storing to a slot must provide a "
                         "new definition of the slot");
        reachingDef = stored;
        replacedValuesMap[memOp] = stored;
      }
    }
  }

  return reachingDef;
}

void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
                                                    Value reachingDef) {
  assert(reachingDef && "expected an initial reaching def to be provided");
  if (region->hasOneBlock()) {
    computeReachingDefInBlock(&region->front(), reachingDef);
    return;
  }

  struct DfsJob {
    llvm::DomTreeNodeBase<Block> *block;
    Value reachingDef;
  };

  SmallVector<DfsJob> dfsStack;

  auto &domTree = dominance.getDomTree(slot.ptr.getParentRegion());

  dfsStack.emplace_back<DfsJob>(
      {domTree.getNode(&region->front()), reachingDef});

  while (!dfsStack.empty()) {
    DfsJob job = dfsStack.pop_back_val();
    Block *block = job.block->getBlock();

    if (info.mergePoints.contains(block)) {
      BlockArgument blockArgument =
          block->addArgument(slot.elemType, slot.ptr.getLoc());
      builder.setInsertionPointToStart(block);
      allocator.handleBlockArgument(slot, blockArgument, builder);
      job.reachingDef = blockArgument;

      if (statistics.newBlockArgumentAmount)
        (*statistics.newBlockArgumentAmount)++;
    }

    job.reachingDef = computeReachingDefInBlock(block, job.reachingDef);
    assert(job.reachingDef);

    if (auto terminator = dyn_cast<BranchOpInterface>(block->getTerminator())) {
      for (BlockOperand &blockOperand : terminator->getBlockOperands()) {
        if (info.mergePoints.contains(blockOperand.get())) {
          terminator.getSuccessorOperands(blockOperand.getOperandNumber())
              .append(job.reachingDef);
        }
      }
    }

    for (auto *child : job.block->children())
      dfsStack.emplace_back<DfsJob>({child, job.reachingDef});
  }
}
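
// Walkthrough of the renaming DFS (explanatory, reusing the merge-point
// example above): ^bb1 and ^bb2 are visited with the reaching definition of
// their immediate dominator, and their branches into the merge point ^bb3 are
// extended to forward %a and %b respectively. When ^bb3 itself is visited,
// its new block argument becomes the reaching definition for the loads it
// contains.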

/// Gets or creates a block index mapping for `region`.
static const DenseMap<Block *, size_t> &
getOrCreateBlockIndices(BlockIndexCache &blockIndexCache, Region *region) {
  auto [it, inserted] = blockIndexCache.try_emplace(region);
  if (!inserted)
    return it->second;

  DenseMap<Block *, size_t> &blockIndices = it->second;
  SetVector<Block *> topologicalOrder = getBlocksSortedByDominance(*region);
  for (auto [index, block] : llvm::enumerate(topologicalOrder))
    blockIndices[block] = index;
  return blockIndices;
}

/// Sorts `ops` according to dominance. Relies on the topological order of
/// basic blocks to get a deterministic ordering. Uses `blockIndexCache` to
/// avoid the potentially expensive recomputation of a block index map.
static void dominanceSort(SmallVector<Operation *> &ops, Region &region,
                          BlockIndexCache &blockIndexCache) {
  // Produce a topological block order and construct a map to look up the
  // indices of blocks.
  const DenseMap<Block *, size_t> &topoBlockIndices =
      getOrCreateBlockIndices(blockIndexCache, &region);

  // Combining the topological order of the basic blocks with the
  // block-internal operation order guarantees a deterministic,
  // dominance-respecting order.
  llvm::sort(ops, [&](Operation *lhs, Operation *rhs) {
    size_t lhsBlockIndex = topoBlockIndices.at(lhs->getBlock());
    size_t rhsBlockIndex = topoBlockIndices.at(rhs->getBlock());
    if (lhsBlockIndex == rhsBlockIndex)
      return lhs->isBeforeInBlock(rhs);
    return lhsBlockIndex < rhsBlockIndex;
  });
}

void MemorySlotPromoter::removeBlockingUses() {
  llvm::SmallVector<Operation *> usersToRemoveUses(
      llvm::make_first_range(info.userToBlockingUses));

  // Sort according to dominance.
  dominanceSort(usersToRemoveUses, *slot.ptr.getParentBlock()->getParent(),
                blockIndexCache);

  llvm::SmallVector<Operation *> toErase;
  // List of all replaced values in the slot.
  llvm::SmallVector<std::pair<Operation *, Value>> replacedValuesList;
  // Ops to visit with the `visitReplacedValues` method.
  llvm::SmallVector<PromotableOpInterface> toVisit;
  for (Operation *toPromote : llvm::reverse(usersToRemoveUses)) {
    if (auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) {
      Value reachingDef = reachingDefs.lookup(toPromoteMemOp);
      // If no reaching definition is known, this use is outside the reach of
      // the slot. The default value should thus be used.
      if (!reachingDef)
        reachingDef = getOrCreateDefaultValue();

      builder.setInsertionPointAfter(toPromote);
      if (toPromoteMemOp.removeBlockingUses(
              slot, info.userToBlockingUses[toPromote], builder, reachingDef,
              dataLayout) == DeletionKind::Delete)
        toErase.push_back(toPromote);
      if (toPromoteMemOp.storesTo(slot))
        if (Value replacedValue = replacedValuesMap[toPromoteMemOp])
          replacedValuesList.push_back({toPromoteMemOp, replacedValue});
      continue;
    }

    auto toPromoteBasic = cast<PromotableOpInterface>(toPromote);
    builder.setInsertionPointAfter(toPromote);
    if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote],
                                          builder) == DeletionKind::Delete)
      toErase.push_back(toPromote);
    if (toPromoteBasic.requiresReplacedValues())
      toVisit.push_back(toPromoteBasic);
  }
  for (PromotableOpInterface op : toVisit) {
    builder.setInsertionPointAfter(op);
    op.visitReplacedValues(replacedValuesList, builder);
  }

  for (Operation *toEraseOp : toErase)
    toEraseOp->erase();

  assert(slot.ptr.use_empty() &&
         "after promotion, the slot pointer should not be used anymore");
}
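
// Note on ordering (explanatory): iterating over the dominance-sorted users
// in reverse processes an operation only after the operations using its
// results, matching the DAG structure of userToBlockingUses, where an op may
// require its users' uses to be removed before it can delete itself.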

std::optional<PromotableAllocationOpInterface>
MemorySlotPromoter::promoteSlot() {
  computeReachingDefInRegion(slot.ptr.getParentRegion(),
                             getOrCreateDefaultValue());

  // Now that reaching definitions are known, remove all users.
  removeBlockingUses();

  // Update terminators in dead branches to forward the default value if they
  // are succeeded by a merge point.
  for (Block *mergePoint : info.mergePoints) {
    for (BlockOperand &use : mergePoint->getUses()) {
      auto user = cast<BranchOpInterface>(use.getOwner());
      SuccessorOperands succOperands =
          user.getSuccessorOperands(use.getOperandNumber());
      assert(succOperands.size() == mergePoint->getNumArguments() ||
             succOperands.size() + 1 == mergePoint->getNumArguments());
      if (succOperands.size() + 1 == mergePoint->getNumArguments())
        succOperands.append(getOrCreateDefaultValue());
    }
  }

  LLVM_DEBUG(llvm::dbgs() << "[mem2reg] Promoted memory slot: " << slot.ptr
                          << "\n");

  if (statistics.promotedAmount)
    (*statistics.promotedAmount)++;

  return allocator.handlePromotionComplete(slot, defaultValue, builder);
}
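
/// Attempts to promote the slots of all given allocators, iterating to a
/// fixpoint: promoting a slot may create a new allocator, which is fed back
/// into the worklist for further processing. Returns success if at least one
/// slot was promoted.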
LogicalResult mlir::tryToPromoteMemorySlots(
    ArrayRef<PromotableAllocationOpInterface> allocators, OpBuilder &builder,
    const DataLayout &dataLayout, DominanceInfo &dominance,
    Mem2RegStatistics statistics) {
  bool promotedAny = false;

  // A cache that stores deterministic block indices which are used to
  // determine a valid operation modification order. The block index maps are
  // computed lazily and cached to avoid expensive recomputation.
  BlockIndexCache blockIndexCache;

  SmallVector<PromotableAllocationOpInterface> workList(allocators);

  SmallVector<PromotableAllocationOpInterface> newWorkList;
  newWorkList.reserve(workList.size());
  while (true) {
    bool changesInThisRound = false;
    for (PromotableAllocationOpInterface allocator : workList) {
      bool changedAllocator = false;
      for (MemorySlot slot : allocator.getPromotableSlots()) {
        if (slot.ptr.use_empty())
          continue;

        MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
        std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
        if (info) {
          std::optional<PromotableAllocationOpInterface> newAllocator =
              MemorySlotPromoter(slot, allocator, builder, dominance,
                                 dataLayout, std::move(*info), statistics,
                                 blockIndexCache)
                  .promoteSlot();
          changedAllocator = true;
          // Add newly created allocators to the worklist for further
          // processing.
          if (newAllocator)
            newWorkList.push_back(*newAllocator);

          // A break is required, since promoting a slot may invalidate the
          // remaining slots of an allocator.
          break;
        }
      }
      if (!changedAllocator)
        newWorkList.push_back(allocator);
      changesInThisRound |= changedAllocator;
    }
    if (!changesInThisRound)
      break;
    promotedAny = true;

    // Swap the vectors' backing memory and clear the entries in newWorkList
    // afterwards. This avoids additional heap allocations.
    workList.swap(newWorkList);
    newWorkList.clear();
  }

  return success(promotedAny);
}

namespace {

struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
  using impl::Mem2RegBase<Mem2Reg>::Mem2RegBase;

  void runOnOperation() override {
    Operation *scopeOp = getOperation();

    Mem2RegStatistics statistics{&promotedAmount, &newBlockArgumentAmount};

    bool changed = false;

    auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);
    auto &dominance = getAnalysis<DominanceInfo>();

    for (Region &region : scopeOp->getRegions()) {
      if (region.getBlocks().empty())
        continue;

      OpBuilder builder(&region.front(), region.front().begin());

      SmallVector<PromotableAllocationOpInterface> allocators;
      // Build a list of allocators whose slots should be promoted.
      region.walk([&](PromotableAllocationOpInterface allocator) {
        allocators.emplace_back(allocator);
      });

      // Attempt promoting as many of the slots as possible.
      if (succeeded(tryToPromoteMemorySlots(allocators, builder, dataLayout,
                                            dominance, statistics)))
        changed = true;
    }
    if (!changed)
      markAllAnalysesPreserved();
  }
};

} // namespace