xref: /llvm-project/mlir/lib/Transforms/Mem2Reg.cpp (revision 5fdda4147495c6bcfb6c100dbf32a669c3b39874)
1 //===- Mem2Reg.cpp - Promotes memory slots into values ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/Mem2Reg.h"
10 #include "mlir/Analysis/DataLayoutAnalysis.h"
11 #include "mlir/Analysis/SliceAnalysis.h"
12 #include "mlir/Analysis/TopologicalSortUtils.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/Dominance.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/IR/RegionKindInterface.h"
17 #include "mlir/IR/Value.h"
18 #include "mlir/Interfaces/ControlFlowInterfaces.h"
19 #include "mlir/Interfaces/MemorySlotInterfaces.h"
20 #include "mlir/Transforms/Passes.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/Support/GenericIteratedDominanceFrontier.h"
23 
24 namespace mlir {
25 #define GEN_PASS_DEF_MEM2REG
26 #include "mlir/Transforms/Passes.h.inc"
27 } // namespace mlir
28 
29 #define DEBUG_TYPE "mem2reg"
30 
31 using namespace mlir;
32 
33 /// mem2reg
34 ///
35 /// This pass turns unnecessary uses of automatically allocated memory slots
36 /// into direct Value-based operations. For example, it will simplify storing a
37 /// constant in a memory slot to immediately load it to a direct use of that
38 /// constant. In other words, given a memory slot addressed by a non-aliased
39 /// "pointer" Value, mem2reg removes all the uses of that pointer.
40 ///
41 /// Within a block, this is done by following the chain of stores and loads of
42 /// the slot and replacing the results of loads with the values previously
43 /// stored. If a load happens before any other store, a poison value is used
44 /// instead.
45 ///
/// Control flow can create situations where a load could be replaced by
/// multiple possible stores depending on the control flow path taken. As a
/// result, this pass must introduce new block arguments in some blocks to
/// accommodate the multiple possible definitions. Each predecessor will
/// populate the block argument with the definition reached at its end. With
/// this, the value stored can be well defined at block boundaries, allowing
/// the propagation of replacements across blocks.
53 ///
/// This pass computes this transformation in four main steps. The first two
/// steps are performed during an analysis phase that does not mutate IR.
///
/// The two steps of the analysis phase are the following:
58 /// - A first step computes the list of operations that transitively use the
59 /// memory slot we would like to promote. The purpose of this phase is to
60 /// identify which uses must be removed to promote the slot, either by rewiring
61 /// the user or deleting it. Naturally, direct uses of the slot must be removed.
62 /// Sometimes additional uses must also be removed: this is notably the case
63 /// when a direct user of the slot cannot rewire its use and must delete itself,
64 /// and thus must make its users no longer use it. If any of those uses cannot
65 /// be removed by their users in any way, promotion cannot continue: this is
66 /// decided at this step.
67 /// - A second step computes the list of blocks where a block argument will be
68 /// needed ("merge points") without mutating the IR. These blocks are the blocks
69 /// leading to a definition clash between two predecessors. Such blocks happen
70 /// to be the Iterated Dominance Frontier (IDF) of the set of blocks containing
71 /// a store, as they represent the point where a clear defining dominator stops
72 /// existing. Computing this information in advance allows making sure the
73 /// terminators that will forward values are capable of doing so (inability to
74 /// do so aborts promotion at this step).
75 ///
76 /// At this point, promotion is guaranteed to happen, and the mutation phase can
77 /// begin with the following steps:
/// - A third step computes the reaching definition of the memory slot at each
/// blocking user. This is the core of the mem2reg algorithm, also known as
/// load-store forwarding. This analyses loads and stores and propagates which
/// value must be stored in the slot at each blocking user. This is achieved by
/// doing a depth-first walk of the dominator tree of the region containing the
/// slot. This is sufficient because the reaching definition at the beginning
/// of a block is either its new block argument if it is a merge block, or the
/// definition reaching the end of its immediate dominator (parent in the
/// dominator tree). We can therefore propagate this information down the
/// dominator tree to proceed with renaming within blocks.
/// - The fourth and final step uses the reaching definition to remove blocking
/// uses in topological order.
90 ///
91 /// For further reading, chapter three of SSA-based Compiler Design [1]
92 /// showcases SSA construction, where mem2reg is an adaptation of the same
93 /// process.
94 ///
95 /// [1]: Rastello F. & Bouchez Tichadou F., SSA-based Compiler Design (2022),
96 ///      Springer.
97 
98 namespace {
99 
100 using BlockingUsesMap =
101     llvm::MapVector<Operation *, SmallPtrSet<OpOperand *, 4>>;
102 
103 /// Information computed during promotion analysis used to perform actual
104 /// promotion.
105 struct MemorySlotPromotionInfo {
106   /// Blocks for which at least two definitions of the slot values clash.
107   SmallPtrSet<Block *, 8> mergePoints;
108   /// Contains, for each operation, which uses must be eliminated by promotion.
109   /// This is a DAG structure because if an operation must eliminate some of
110   /// its uses, it is because the defining ops of the blocking uses requested
111   /// it. The defining ops therefore must also have blocking uses or be the
112   /// starting point of the blocking uses.
113   BlockingUsesMap userToBlockingUses;
114 };
115 
116 /// Computes information for basic slot promotion. This will check that direct
117 /// slot promotion can be performed, and provide the information to execute the
118 /// promotion. This does not mutate IR.
119 class MemorySlotPromotionAnalyzer {
120 public:
121   MemorySlotPromotionAnalyzer(MemorySlot slot, DominanceInfo &dominance,
122                               const DataLayout &dataLayout)
123       : slot(slot), dominance(dominance), dataLayout(dataLayout) {}
124 
125   /// Computes the information for slot promotion if promotion is possible,
126   /// returns nothing otherwise.
127   std::optional<MemorySlotPromotionInfo> computeInfo();
128 
129 private:
130   /// Computes the transitive uses of the slot that block promotion. This finds
131   /// uses that would block the promotion, checks that the operation has a
132   /// solution to remove the blocking use, and potentially forwards the analysis
133   /// if the operation needs further blocking uses resolved to resolve its own
134   /// uses (typically, removing its users because it will delete itself to
135   /// resolve its own blocking uses). This will fail if one of the transitive
136   /// users cannot remove a requested use, and should prevent promotion.
137   LogicalResult computeBlockingUses(BlockingUsesMap &userToBlockingUses);
138 
139   /// Computes in which blocks the value stored in the slot is actually used,
140   /// meaning blocks leading to a load. This method uses `definingBlocks`, the
141   /// set of blocks containing a store to the slot (defining the value of the
142   /// slot).
143   SmallPtrSet<Block *, 16>
144   computeSlotLiveIn(SmallPtrSetImpl<Block *> &definingBlocks);
145 
146   /// Computes the points in which multiple re-definitions of the slot's value
147   /// (stores) may conflict.
148   void computeMergePoints(SmallPtrSetImpl<Block *> &mergePoints);
149 
150   /// Ensures predecessors of merge points can properly provide their current
151   /// definition of the value stored in the slot to the merge point. This can
152   /// notably be an issue if the terminator used does not have the ability to
153   /// forward values through block operands.
154   bool areMergePointsUsable(SmallPtrSetImpl<Block *> &mergePoints);
155 
156   MemorySlot slot;
157   DominanceInfo &dominance;
158   const DataLayout &dataLayout;
159 };
160 
161 using BlockIndexCache = DenseMap<Region *, DenseMap<Block *, size_t>>;
162 
163 /// The MemorySlotPromoter handles the state of promoting a memory slot. It
164 /// wraps a slot and its associated allocator. This will perform the mutation of
165 /// IR.
166 class MemorySlotPromoter {
167 public:
168   MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator,
169                      OpBuilder &builder, DominanceInfo &dominance,
170                      const DataLayout &dataLayout, MemorySlotPromotionInfo info,
171                      const Mem2RegStatistics &statistics,
172                      BlockIndexCache &blockIndexCache);
173 
174   /// Actually promotes the slot by mutating IR. Promoting a slot DOES
175   /// invalidate the MemorySlotPromotionInfo of other slots. Preparation of
176   /// promotion info should NOT be performed in batches.
177   /// Returns a promotable allocation op if a new allocator was created, nullopt
178   /// otherwise.
179   std::optional<PromotableAllocationOpInterface> promoteSlot();
180 
181 private:
182   /// Computes the reaching definition for all the operations that require
183   /// promotion. `reachingDef` is the value the slot should contain at the
184   /// beginning of the block. This method returns the reached definition at the
185   /// end of the block. This method must only be called at most once per block.
186   Value computeReachingDefInBlock(Block *block, Value reachingDef);
187 
188   /// Computes the reaching definition for all the operations that require
189   /// promotion. `reachingDef` corresponds to the initial value the
190   /// slot will contain before any write, typically a poison value.
191   /// This method must only be called at most once per region.
192   void computeReachingDefInRegion(Region *region, Value reachingDef);
193 
194   /// Removes the blocking uses of the slot, in topological order.
195   void removeBlockingUses();
196 
197   /// Lazily-constructed default value representing the content of the slot when
198   /// no store has been executed. This function may mutate IR.
199   Value getOrCreateDefaultValue();
200 
201   MemorySlot slot;
202   PromotableAllocationOpInterface allocator;
203   OpBuilder &builder;
204   /// Potentially non-initialized default value. Use `getOrCreateDefaultValue`
205   /// to initialize it on demand.
206   Value defaultValue;
207   /// Contains the reaching definition at this operation. Reaching definitions
208   /// are only computed for promotable memory operations with blocking uses.
209   DenseMap<PromotableMemOpInterface, Value> reachingDefs;
210   DenseMap<PromotableMemOpInterface, Value> replacedValuesMap;
211   DominanceInfo &dominance;
212   const DataLayout &dataLayout;
213   MemorySlotPromotionInfo info;
214   const Mem2RegStatistics &statistics;
215 
216   /// Shared cache of block indices of specific regions.
217   BlockIndexCache &blockIndexCache;
218 };
219 
220 } // namespace
221 
222 MemorySlotPromoter::MemorySlotPromoter(
223     MemorySlot slot, PromotableAllocationOpInterface allocator,
224     OpBuilder &builder, DominanceInfo &dominance, const DataLayout &dataLayout,
225     MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics,
226     BlockIndexCache &blockIndexCache)
227     : slot(slot), allocator(allocator), builder(builder), dominance(dominance),
228       dataLayout(dataLayout), info(std::move(info)), statistics(statistics),
229       blockIndexCache(blockIndexCache) {
230 #ifndef NDEBUG
231   auto isResultOrNewBlockArgument = [&]() {
232     if (BlockArgument arg = dyn_cast<BlockArgument>(slot.ptr))
233       return arg.getOwner()->getParentOp() == allocator;
234     return slot.ptr.getDefiningOp() == allocator;
235   };
236 
237   assert(isResultOrNewBlockArgument() &&
238          "a slot must be a result of the allocator or an argument of the child "
239          "regions of the allocator");
240 #endif // NDEBUG
241 }
242 
243 Value MemorySlotPromoter::getOrCreateDefaultValue() {
244   if (defaultValue)
245     return defaultValue;
246 
247   OpBuilder::InsertionGuard guard(builder);
248   builder.setInsertionPointToStart(slot.ptr.getParentBlock());
249   return defaultValue = allocator.getDefaultValue(slot, builder);
250 }
251 
252 LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
253     BlockingUsesMap &userToBlockingUses) {
254   // The promotion of an operation may require the promotion of further
255   // operations (typically, removing operations that use an operation that must
256   // delete itself). We thus need to start from the use of the slot pointer and
257   // propagate further requests through the forward slice.
258 
259   // Because this pass currently only supports analysing the parent region of
260   // the slot pointer, if a promotable memory op that needs promotion is within
261   // a graph region, the slot may only be used in a graph region and should
262   // therefore be ignored.
263   Region *slotPtrRegion = slot.ptr.getParentRegion();
264   auto slotPtrRegionOp =
265       dyn_cast<RegionKindInterface>(slotPtrRegion->getParentOp());
266   if (slotPtrRegionOp &&
267       slotPtrRegionOp.getRegionKind(slotPtrRegion->getRegionNumber()) ==
268           RegionKind::Graph)
269     return failure();
270 
271   // First insert that all immediate users of the slot pointer must no longer
272   // use it.
273   for (OpOperand &use : slot.ptr.getUses()) {
274     SmallPtrSet<OpOperand *, 4> &blockingUses =
275         userToBlockingUses[use.getOwner()];
276     blockingUses.insert(&use);
277   }
278 
279   // Then, propagate the requirements for the removal of uses. The
280   // topologically-sorted forward slice allows for all blocking uses of an
281   // operation to have been computed before it is reached. Operations are
282   // traversed in topological order of their uses, starting from the slot
283   // pointer.
284   SetVector<Operation *> forwardSlice;
285   mlir::getForwardSlice(slot.ptr, &forwardSlice);
286   for (Operation *user : forwardSlice) {
287     // If the next operation has no blocking uses, everything is fine.
288     auto it = userToBlockingUses.find(user);
289     if (it == userToBlockingUses.end())
290       continue;
291 
292     SmallPtrSet<OpOperand *, 4> &blockingUses = it->second;
293 
294     SmallVector<OpOperand *> newBlockingUses;
295     // If the operation decides it cannot deal with removing the blocking uses,
296     // promotion must fail.
297     if (auto promotable = dyn_cast<PromotableOpInterface>(user)) {
298       if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses,
299                                        dataLayout))
300         return failure();
301     } else if (auto promotable = dyn_cast<PromotableMemOpInterface>(user)) {
302       if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses,
303                                        dataLayout))
304         return failure();
305     } else {
306       // An operation that has blocking uses must be promoted. If it is not
307       // promotable, promotion must fail.
308       return failure();
309     }
310 
311     // Then, register any new blocking uses for coming operations.
312     for (OpOperand *blockingUse : newBlockingUses) {
313       assert(llvm::is_contained(user->getResults(), blockingUse->get()));
314 
315       SmallPtrSetImpl<OpOperand *> &newUserBlockingUseSet =
316           userToBlockingUses[blockingUse->getOwner()];
317       newUserBlockingUseSet.insert(blockingUse);
318     }
319   }
320 
321   // Because this pass currently only supports analysing the parent region of
322   // the slot pointer, if a promotable memory op that needs promotion is outside
323   // of this region, promotion must fail because it will be impossible to
324   // provide a valid `reachingDef` for it.
325   for (auto &[toPromote, _] : userToBlockingUses)
326     if (isa<PromotableMemOpInterface>(toPromote) &&
327         toPromote->getParentRegion() != slot.ptr.getParentRegion())
328       return failure();
329 
330   return success();
331 }
332 
333 SmallPtrSet<Block *, 16> MemorySlotPromotionAnalyzer::computeSlotLiveIn(
334     SmallPtrSetImpl<Block *> &definingBlocks) {
335   SmallPtrSet<Block *, 16> liveIn;
336 
337   // The worklist contains blocks in which it is known that the slot value is
338   // live-in. The further blocks where this value is live-in will be inferred
339   // from these.
340   SmallVector<Block *> liveInWorkList;
341 
342   // Blocks with a load before any other store to the slot are the starting
343   // points of the analysis. The slot value is definitely live-in in those
344   // blocks.
345   SmallPtrSet<Block *, 16> visited;
346   for (Operation *user : slot.ptr.getUsers()) {
347     if (!visited.insert(user->getBlock()).second)
348       continue;
349 
350     for (Operation &op : user->getBlock()->getOperations()) {
351       if (auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
352         // If this operation loads the slot, it is loading from it before
353         // ever writing to it, so the value is live-in in this block.
354         if (memOp.loadsFrom(slot)) {
355           liveInWorkList.push_back(user->getBlock());
356           break;
357         }
358 
359         // If we store to the slot, further loads will see that value.
360         // Because we did not meet any load before, the value is not live-in.
361         if (memOp.storesTo(slot))
362           break;
363       }
364     }
365   }
366 
367   // The information is then propagated to the predecessors until a def site
368   // (store) is found.
369   while (!liveInWorkList.empty()) {
370     Block *liveInBlock = liveInWorkList.pop_back_val();
371 
372     if (!liveIn.insert(liveInBlock).second)
373       continue;
374 
375     // If a predecessor is a defining block, either:
376     // - It has a load before its first store, in which case it is live-in but
377     // has already been processed in the initialisation step.
378     // - It has a store before any load, in which case it is not live-in.
379     // We can thus at this stage insert to the worklist only predecessors that
380     // are not defining blocks.
381     for (Block *pred : liveInBlock->getPredecessors())
382       if (!definingBlocks.contains(pred))
383         liveInWorkList.push_back(pred);
384   }
385 
386   return liveIn;
387 }
388 
389 using IDFCalculator = llvm::IDFCalculatorBase<Block, false>;
390 void MemorySlotPromotionAnalyzer::computeMergePoints(
391     SmallPtrSetImpl<Block *> &mergePoints) {
392   if (slot.ptr.getParentRegion()->hasOneBlock())
393     return;
394 
395   IDFCalculator idfCalculator(dominance.getDomTree(slot.ptr.getParentRegion()));
396 
397   SmallPtrSet<Block *, 16> definingBlocks;
398   for (Operation *user : slot.ptr.getUsers())
399     if (auto storeOp = dyn_cast<PromotableMemOpInterface>(user))
400       if (storeOp.storesTo(slot))
401         definingBlocks.insert(user->getBlock());
402 
403   idfCalculator.setDefiningBlocks(definingBlocks);
404 
405   SmallPtrSet<Block *, 16> liveIn = computeSlotLiveIn(definingBlocks);
406   idfCalculator.setLiveInBlocks(liveIn);
407 
408   SmallVector<Block *> mergePointsVec;
409   idfCalculator.calculate(mergePointsVec);
410 
411   mergePoints.insert(mergePointsVec.begin(), mergePointsVec.end());
412 }
413 
414 bool MemorySlotPromotionAnalyzer::areMergePointsUsable(
415     SmallPtrSetImpl<Block *> &mergePoints) {
416   for (Block *mergePoint : mergePoints)
417     for (Block *pred : mergePoint->getPredecessors())
418       if (!isa<BranchOpInterface>(pred->getTerminator()))
419         return false;
420 
421   return true;
422 }
423 
424 std::optional<MemorySlotPromotionInfo>
425 MemorySlotPromotionAnalyzer::computeInfo() {
426   MemorySlotPromotionInfo info;
427 
428   // First, find the set of operations that will need to be changed for the
429   // promotion to happen. These operations need to resolve some of their uses,
430   // either by rewiring them or simply deleting themselves. If any of them
431   // cannot find a way to resolve their blocking uses, we abort the promotion.
432   if (failed(computeBlockingUses(info.userToBlockingUses)))
433     return {};
434 
435   // Then, compute blocks in which two or more definitions of the allocated
436   // variable may conflict. These blocks will need a new block argument to
437   // accommodate this.
438   computeMergePoints(info.mergePoints);
439 
440   // The slot can be promoted if the block arguments to be created can
441   // actually be populated with values, which may not be possible depending
442   // on their predecessors.
443   if (!areMergePointsUsable(info.mergePoints))
444     return {};
445 
446   return info;
447 }
448 
449 Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
450                                                     Value reachingDef) {
451   SmallVector<Operation *> blockOps;
452   for (Operation &op : block->getOperations())
453     blockOps.push_back(&op);
454   for (Operation *op : blockOps) {
455     if (auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
456       if (info.userToBlockingUses.contains(memOp))
457         reachingDefs.insert({memOp, reachingDef});
458 
459       if (memOp.storesTo(slot)) {
460         builder.setInsertionPointAfter(memOp);
461         Value stored = memOp.getStored(slot, builder, reachingDef, dataLayout);
462         assert(stored && "a memory operation storing to a slot must provide a "
463                          "new definition of the slot");
464         reachingDef = stored;
465         replacedValuesMap[memOp] = stored;
466       }
467     }
468   }
469 
470   return reachingDef;
471 }
472 
473 void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
474                                                     Value reachingDef) {
475   assert(reachingDef && "expected an initial reaching def to be provided");
476   if (region->hasOneBlock()) {
477     computeReachingDefInBlock(&region->front(), reachingDef);
478     return;
479   }
480 
481   struct DfsJob {
482     llvm::DomTreeNodeBase<Block> *block;
483     Value reachingDef;
484   };
485 
486   SmallVector<DfsJob> dfsStack;
487 
488   auto &domTree = dominance.getDomTree(slot.ptr.getParentRegion());
489 
490   dfsStack.emplace_back<DfsJob>(
491       {domTree.getNode(&region->front()), reachingDef});
492 
493   while (!dfsStack.empty()) {
494     DfsJob job = dfsStack.pop_back_val();
495     Block *block = job.block->getBlock();
496 
497     if (info.mergePoints.contains(block)) {
498       BlockArgument blockArgument =
499           block->addArgument(slot.elemType, slot.ptr.getLoc());
500       builder.setInsertionPointToStart(block);
501       allocator.handleBlockArgument(slot, blockArgument, builder);
502       job.reachingDef = blockArgument;
503 
504       if (statistics.newBlockArgumentAmount)
505         (*statistics.newBlockArgumentAmount)++;
506     }
507 
508     job.reachingDef = computeReachingDefInBlock(block, job.reachingDef);
509     assert(job.reachingDef);
510 
511     if (auto terminator = dyn_cast<BranchOpInterface>(block->getTerminator())) {
512       for (BlockOperand &blockOperand : terminator->getBlockOperands()) {
513         if (info.mergePoints.contains(blockOperand.get())) {
514           terminator.getSuccessorOperands(blockOperand.getOperandNumber())
515               .append(job.reachingDef);
516         }
517       }
518     }
519 
520     for (auto *child : job.block->children())
521       dfsStack.emplace_back<DfsJob>({child, job.reachingDef});
522   }
523 }
524 
525 /// Gets or creates a block index mapping for `region`.
526 static const DenseMap<Block *, size_t> &
527 getOrCreateBlockIndices(BlockIndexCache &blockIndexCache, Region *region) {
528   auto [it, inserted] = blockIndexCache.try_emplace(region);
529   if (!inserted)
530     return it->second;
531 
532   DenseMap<Block *, size_t> &blockIndices = it->second;
533   SetVector<Block *> topologicalOrder = getBlocksSortedByDominance(*region);
534   for (auto [index, block] : llvm::enumerate(topologicalOrder))
535     blockIndices[block] = index;
536   return blockIndices;
537 }
538 
539 /// Sorts `ops` according to dominance. Relies on the topological order of basic
540 /// blocks to get a deterministic ordering. Uses `blockIndexCache` to avoid the
541 /// potentially expensive recomputation of a block index map.
542 static void dominanceSort(SmallVector<Operation *> &ops, Region &region,
543                           BlockIndexCache &blockIndexCache) {
544   // Produce a topological block order and construct a map to lookup the indices
545   // of blocks.
546   const DenseMap<Block *, size_t> &topoBlockIndices =
547       getOrCreateBlockIndices(blockIndexCache, &region);
548 
549   // Combining the topological order of the basic blocks together with block
550   // internal operation order guarantees a deterministic, dominance respecting
551   // order.
552   llvm::sort(ops, [&](Operation *lhs, Operation *rhs) {
553     size_t lhsBlockIndex = topoBlockIndices.at(lhs->getBlock());
554     size_t rhsBlockIndex = topoBlockIndices.at(rhs->getBlock());
555     if (lhsBlockIndex == rhsBlockIndex)
556       return lhs->isBeforeInBlock(rhs);
557     return lhsBlockIndex < rhsBlockIndex;
558   });
559 }
560 
561 void MemorySlotPromoter::removeBlockingUses() {
562   llvm::SmallVector<Operation *> usersToRemoveUses(
563       llvm::make_first_range(info.userToBlockingUses));
564 
565   // Sort according to dominance.
566   dominanceSort(usersToRemoveUses, *slot.ptr.getParentBlock()->getParent(),
567                 blockIndexCache);
568 
569   llvm::SmallVector<Operation *> toErase;
570   // List of all replaced values in the slot.
571   llvm::SmallVector<std::pair<Operation *, Value>> replacedValuesList;
572   // Ops to visit with the `visitReplacedValues` method.
573   llvm::SmallVector<PromotableOpInterface> toVisit;
574   for (Operation *toPromote : llvm::reverse(usersToRemoveUses)) {
575     if (auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) {
576       Value reachingDef = reachingDefs.lookup(toPromoteMemOp);
577       // If no reaching definition is known, this use is outside the reach of
578       // the slot. The default value should thus be used.
579       if (!reachingDef)
580         reachingDef = getOrCreateDefaultValue();
581 
582       builder.setInsertionPointAfter(toPromote);
583       if (toPromoteMemOp.removeBlockingUses(
584               slot, info.userToBlockingUses[toPromote], builder, reachingDef,
585               dataLayout) == DeletionKind::Delete)
586         toErase.push_back(toPromote);
587       if (toPromoteMemOp.storesTo(slot))
588         if (Value replacedValue = replacedValuesMap[toPromoteMemOp])
589           replacedValuesList.push_back({toPromoteMemOp, replacedValue});
590       continue;
591     }
592 
593     auto toPromoteBasic = cast<PromotableOpInterface>(toPromote);
594     builder.setInsertionPointAfter(toPromote);
595     if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote],
596                                           builder) == DeletionKind::Delete)
597       toErase.push_back(toPromote);
598     if (toPromoteBasic.requiresReplacedValues())
599       toVisit.push_back(toPromoteBasic);
600   }
601   for (PromotableOpInterface op : toVisit) {
602     builder.setInsertionPointAfter(op);
603     op.visitReplacedValues(replacedValuesList, builder);
604   }
605 
606   for (Operation *toEraseOp : toErase)
607     toEraseOp->erase();
608 
609   assert(slot.ptr.use_empty() &&
610          "after promotion, the slot pointer should not be used anymore");
611 }
612 
613 std::optional<PromotableAllocationOpInterface>
614 MemorySlotPromoter::promoteSlot() {
615   computeReachingDefInRegion(slot.ptr.getParentRegion(),
616                              getOrCreateDefaultValue());
617 
618   // Now that reaching definitions are known, remove all users.
619   removeBlockingUses();
620 
621   // Update terminators in dead branches to forward default if they are
622   // succeeded by a merge points.
623   for (Block *mergePoint : info.mergePoints) {
624     for (BlockOperand &use : mergePoint->getUses()) {
625       auto user = cast<BranchOpInterface>(use.getOwner());
626       SuccessorOperands succOperands =
627           user.getSuccessorOperands(use.getOperandNumber());
628       assert(succOperands.size() == mergePoint->getNumArguments() ||
629              succOperands.size() + 1 == mergePoint->getNumArguments());
630       if (succOperands.size() + 1 == mergePoint->getNumArguments())
631         succOperands.append(getOrCreateDefaultValue());
632     }
633   }
634 
635   LLVM_DEBUG(llvm::dbgs() << "[mem2reg] Promoted memory slot: " << slot.ptr
636                           << "\n");
637 
638   if (statistics.promotedAmount)
639     (*statistics.promotedAmount)++;
640 
641   return allocator.handlePromotionComplete(slot, defaultValue, builder);
642 }
643 
644 LogicalResult mlir::tryToPromoteMemorySlots(
645     ArrayRef<PromotableAllocationOpInterface> allocators, OpBuilder &builder,
646     const DataLayout &dataLayout, DominanceInfo &dominance,
647     Mem2RegStatistics statistics) {
648   bool promotedAny = false;
649 
650   // A cache that stores deterministic block indices which are used to determine
651   // a valid operation modification order. The block index maps are computed
652   // lazily and cached to avoid expensive recomputation.
653   BlockIndexCache blockIndexCache;
654 
655   SmallVector<PromotableAllocationOpInterface> workList(allocators);
656 
657   SmallVector<PromotableAllocationOpInterface> newWorkList;
658   newWorkList.reserve(workList.size());
659   while (true) {
660     bool changesInThisRound = false;
661     for (PromotableAllocationOpInterface allocator : workList) {
662       bool changedAllocator = false;
663       for (MemorySlot slot : allocator.getPromotableSlots()) {
664         if (slot.ptr.use_empty())
665           continue;
666 
667         MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
668         std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
669         if (info) {
670           std::optional<PromotableAllocationOpInterface> newAllocator =
671               MemorySlotPromoter(slot, allocator, builder, dominance,
672                                  dataLayout, std::move(*info), statistics,
673                                  blockIndexCache)
674                   .promoteSlot();
675           changedAllocator = true;
676           // Add newly created allocators to the worklist for further
677           // processing.
678           if (newAllocator)
679             newWorkList.push_back(*newAllocator);
680 
681           // A break is required, since promoting a slot may invalidate the
682           // remaining slots of an allocator.
683           break;
684         }
685       }
686       if (!changedAllocator)
687         newWorkList.push_back(allocator);
688       changesInThisRound |= changedAllocator;
689     }
690     if (!changesInThisRound)
691       break;
692     promotedAny = true;
693 
694     // Swap the vector's backing memory and clear the entries in newWorkList
695     // afterwards. This ensures that additional heap allocations can be avoided.
696     workList.swap(newWorkList);
697     newWorkList.clear();
698   }
699 
700   return success(promotedAny);
701 }
702 
703 namespace {
704 
705 struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
706   using impl::Mem2RegBase<Mem2Reg>::Mem2RegBase;
707 
708   void runOnOperation() override {
709     Operation *scopeOp = getOperation();
710 
711     Mem2RegStatistics statistics{&promotedAmount, &newBlockArgumentAmount};
712 
713     bool changed = false;
714 
715     auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
716     const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);
717     auto &dominance = getAnalysis<DominanceInfo>();
718 
719     for (Region &region : scopeOp->getRegions()) {
720       if (region.getBlocks().empty())
721         continue;
722 
723       OpBuilder builder(&region.front(), region.front().begin());
724 
725       SmallVector<PromotableAllocationOpInterface> allocators;
726       // Build a list of allocators to attempt to promote the slots of.
727       region.walk([&](PromotableAllocationOpInterface allocator) {
728         allocators.emplace_back(allocator);
729       });
730 
731       // Attempt promoting as many of the slots as possible.
732       if (succeeded(tryToPromoteMemorySlots(allocators, builder, dataLayout,
733                                             dominance, statistics)))
734         changed = true;
735     }
736     if (!changed)
737       markAllAnalysesPreserved();
738   }
739 };
740 
741 } // namespace
742